mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
dd8b7d28ae
commit
35c570c80e
@ -33,7 +33,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..utils import ExplicitEnum, ModelOutput, logging
|
||||
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .configuration_utils import GenerationConfig
|
||||
@ -80,6 +80,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
||||
|
||||
|
||||
@dataclass
|
||||
class GreedySearchDecoderOnlyOutput(ModelOutput):
|
||||
@ -631,8 +634,11 @@ class GenerationMixin:
|
||||
encoder = self.get_encoder()
|
||||
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
|
||||
# as the inputs.
|
||||
if hasattr(encoder, "_hf_hook"):
|
||||
encoder._hf_hook.io_same_device = True
|
||||
if hasattr(self, "hf_device_map"):
|
||||
if hasattr(encoder, "_hf_hook"):
|
||||
encoder._hf_hook.io_same_device = True
|
||||
else:
|
||||
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
|
||||
|
||||
# 2. Prepare encoder args and encoder kwargs from model kwargs.
|
||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||
|
Loading…
Reference in New Issue
Block a user