fix encoder hook (#25735)

* fix encoder hook

* style
Marc Sun 2023-08-25 09:36:41 -04:00 committed by GitHub
parent dd8b7d28ae
commit 35c570c80e


@@ -33,7 +33,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
-from ..utils import ExplicitEnum, ModelOutput, logging
+from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .configuration_utils import GenerationConfig
@@ -80,6 +80,9 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
+if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
+
 
 @dataclass
 class GreedySearchDecoderOnlyOutput(ModelOutput):
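
For reference, a minimal, self-contained sketch of what an AlignDevicesHook with io_same_device=True does once attached via add_hook_to_module: after the module's forward pass, Accelerate moves the outputs back to whatever device the inputs arrived on. The toy Linear module and the explicit execution_device=0 argument below are illustrative additions so the snippet runs on its own; the patch itself only passes io_same_device=True, since a model dispatched with a device map already has per-submodule hooks handling weight placement.

import torch
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# Toy stand-in for the encoder of a seq2seq model (hypothetical example).
encoder = torch.nn.Linear(4, 4)

if torch.cuda.is_available():
    encoder.to("cuda:0")
    # execution_device moves inputs onto the GPU before forward runs;
    # io_same_device=True moves the outputs back to the inputs' original device.
    add_hook_to_module(encoder, AlignDevicesHook(execution_device=0, io_same_device=True))

    x = torch.randn(2, 4)   # CPU tensor
    y = encoder(x)          # forward runs on cuda:0 via the hook
    print(y.device)         # cpu -- outputs follow the input device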
@@ -631,8 +634,11 @@ class GenerationMixin:
         encoder = self.get_encoder()
         # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
         # as the inputs.
-        if hasattr(encoder, "_hf_hook"):
-            encoder._hf_hook.io_same_device = True
+        if hasattr(self, "hf_device_map"):
+            if hasattr(encoder, "_hf_hook"):
+                encoder._hf_hook.io_same_device = True
+            else:
+                add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
 
         # 2. Prepare encoder args and encoder kwargs from model kwargs.
         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
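
As an illustration of the scenario this change targets (a sketch, not part of the patch): when an encoder-decoder checkpoint is loaded with Accelerate's device_map, the top-level encoder module may have no _hf_hook of its own, so the old code silently skipped setting io_same_device and the encoder hidden states could land on a different device than the inputs. The checkpoint name below is just an example, and the snippet assumes accelerate is installed and at least one accelerator is visible.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"  # example checkpoint; any encoder-decoder model applies
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
inputs = inputs.to(model.device)

# generate() now attaches AlignDevicesHook(io_same_device=True) to the encoder when the
# model has an hf_device_map but the encoder carries no _hf_hook, so the encoder outputs
# come back on the same device as input_ids before the decoder consumes them.
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))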