Generation / FIX: Fix multi-device generation (#30746)

* attempt to fix multi-device generation

* fix

* final fix

* final fix

* fix

* fix

* fix

* fix

* add joao suggestion

* fix
Younes Belkada authored 2024-05-13 14:35:45 +02:00, committed by GitHub
parent a0779b9e19
commit f823fec53e

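The change below targets generation with models dispatched across several devices (e.g. via accelerate's `device_map="auto"`): special-token tensors were previously created on `self.device`, which can differ from the device of the prepared inputs. A minimal sketch of that setup, assuming at least two CUDA devices and accelerate installed; the `gpt2` checkpoint is only a placeholder and this is not a verified reproduction:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

    # If the inputs end up on a device other than `model.device`, tensors such as
    # `pad_token_id` / `eos_token_id` used to be created on `model.device`, so ops that
    # mix them with the input ids could raise a cross-device error during generate().
    inputs = tokenizer(["Hello world", "Hi"], return_tensors="pt", padding=True).to("cuda:1")
    out = model.generate(**inputs, max_new_tokens=5)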

@@ -476,6 +476,7 @@ class GenerationMixin:
             )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
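As context for the hunk above, the padding-derived mask is only used when it can be trusted (a pad token occurs in the inputs and is distinct from the eos token); otherwise the all-ones default wins. A small standalone sketch of the same arithmetic with made-up token ids (5 plays the role of the pad token):

    import torch

    inputs = torch.tensor([[5, 5, 42, 7], [42, 7, 9, 3]])
    pad_token_id = torch.tensor(5)
    default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long)

    can_infer_attention_mask = torch.tensor(True)  # pad token present and distinct from eos in this toy case
    attention_mask_from_padding = inputs.ne(pad_token_id).long()
    attention_mask = (
        attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
    )
    print(attention_mask)  # tensor([[0, 0, 1, 1], [1, 1, 1, 1]])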
@@ -1340,7 +1341,10 @@ class GenerationMixin:
             return self._static_cache
 
     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
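The new `device` keyword is annotated as `Optional[Union[torch.device, str]]`, since `torch.tensor(..., device=...)` accepts either form; a quick illustration (CPU-only so it runs anywhere):

    import torch

    t1 = torch.tensor(2, device="cpu", dtype=torch.long)                 # string spec
    t2 = torch.tensor(2, device=torch.device("cpu"), dtype=torch.long)   # torch.device spec
    assert t1.device == t2.device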
@@ -1352,15 +1356,18 @@
         """
         # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token):
+        def _tensor_or_none(token, device=None):
+            if device is None:
+                device = self.device
+
             if token is None or isinstance(token, torch.Tensor):
                 return token
-            return torch.tensor(token, device=self.device, dtype=torch.long)
+            return torch.tensor(token, device=device, dtype=torch.long)
 
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
+        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
+        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
+        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
+        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
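The helper now falls back to `self.device` only when the caller does not pass a device, so the special-token tensors can follow the inputs instead of the model's first parameter. A self-contained sketch of the same pattern; the class and values are illustrative, not the library's API:

    import torch

    class _DummyMixin:
        device = torch.device("cpu")  # stands in for the model's first-parameter device

        def _tensor_or_none(self, token, device=None):
            # Fall back to the module's own device only when the caller did not pick one.
            if device is None:
                device = self.device
            if token is None or isinstance(token, torch.Tensor):
                return token
            return torch.tensor(token, device=device, dtype=torch.long)

    m = _DummyMixin()
    print(m._tensor_or_none(2))                # created on m.device
    print(m._tensor_or_none(2, device="cpu"))  # created on the caller-supplied device
    print(m._tensor_or_none(None))             # None passes through unchanged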
@@ -1511,7 +1518,6 @@ class GenerationMixin:
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
 
         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
@@ -1519,6 +1525,9 @@
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
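With the reordering above, `generate()` first resolves the model inputs, takes their device, and only then prepares the special tokens on that device; the earlier call site (previous hunk) is removed. A hedged smoke-test sketch of the new keyword, assuming the generation config is overwritten in place as the docstring above states; `gpt2` is again just a placeholder:

    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    generation_config = GenerationConfig.from_model_config(model.config)

    # The processed special tokens should now land on the requested device
    # rather than unconditionally on model.device.
    model._prepare_special_tokens(generation_config, kwargs_has_attention_mask=False, device=torch.device("cpu"))
    print(generation_config.eos_token_id)  # expected: a torch.Tensor on the requested device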