Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
Generation / FIX: Fix multi-device generation (#30746)
* attempt to fix multi-device generation
* fix
* final fix
* final fix
* fix
* fix
* fix
* fix
* add joao suggestion
* fix
This commit is contained in:
parent a0779b9e19
commit f823fec53e
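For context, a minimal sketch of the setup this fix targets, assuming an accelerate-style `device_map="auto"` dispatch; the checkpoint name is illustrative. When a model is sharded across several GPUs, the device of the inputs can differ from `self.device`, so tensors created inside `generate()` need to follow the inputs.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Shard the model over all visible GPUs; later layers may land on a different
# device than the embeddings that determine `model.device`.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
# Before this commit, special-token tensors were always created on `self.device`,
# which could trigger cross-device errors during decoding on sharded models.
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))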
@@ -476,6 +476,7 @@ class GenerationMixin:
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
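As a standalone illustration of the padding-based mask inference in this hunk, here is a small sketch using the same variable names with made-up values:

import torch

pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor(2)
inputs = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 10]])  # first row is left-padded
default_attention_mask = torch.ones_like(inputs)

is_pad_token_in_inputs = torch.isin(elements=inputs, test_elements=pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(elements=pad_token_id, test_elements=eos_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = inputs.ne(pad_token_id).long()
# Padding positions are masked out only when the mask can actually be inferred.
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # tensor([[0, 0, 1, 1], [1, 1, 1, 1]])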
@@ -1340,7 +1341,10 @@ class GenerationMixin:
         return self._static_cache

     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
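The new `device` parameter is typed `Optional[Union[torch.device, str]]`; a quick sketch of why either spelling works once it reaches `torch.tensor` (values are arbitrary):

import torch

t_from_str = torch.tensor(2, device="cpu", dtype=torch.long)
t_from_dev = torch.tensor(2, device=torch.device("cpu"), dtype=torch.long)
assert t_from_str.device == t_from_dev.device  # both resolve to the same device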
@@ -1352,15 +1356,18 @@ class GenerationMixin:
         """

         # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token):
+        def _tensor_or_none(token, device=None):
+            if device is None:
+                device = self.device
+
             if token is None or isinstance(token, torch.Tensor):
                 return token
-            return torch.tensor(token, device=self.device, dtype=torch.long)
+            return torch.tensor(token, device=device, dtype=torch.long)

-        bos_token_id = _tensor_or_none(generation_config.bos_token_id)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
+        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
+        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
+        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
+        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id

         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
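A self-contained sketch of the reworked helper's behavior; since this runs outside the mixin, `self.device` is replaced by an explicit `fallback_device` argument for illustration:

import torch

def tensor_or_none(token, device=None, fallback_device="cpu"):
    # Mirror of the reworked closure: only fall back when no explicit device is
    # given, and pass through None / existing tensors untouched.
    if device is None:
        device = fallback_device
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)

print(tensor_or_none(None))                      # None
print(tensor_or_none(2, device="cpu"))           # tensor(2)
print(tensor_or_none([2, 32000], device="cpu"))  # tensor([    2, 32000]) -> several eos ids stay a 1D tensor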
@@ -1511,7 +1518,6 @@ class GenerationMixin:
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)

         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
@@ -1519,6 +1525,9 @@ class GenerationMixin:
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
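Finally, a small sketch of the invariant the new call site establishes: `device` is taken from `inputs_tensor.device`, so the converted special tokens end up colocated with the prompt. Device ids are illustrative and the check only runs with two visible CUDA devices.

import torch

if torch.cuda.device_count() >= 2:
    inputs_tensor = torch.tensor([[1, 2, 3]], device="cuda:1")  # prompt lives on cuda:1
    device = inputs_tensor.device                               # what generate() now forwards
    eos_token_id = torch.tensor([2], device=device, dtype=torch.long)
    # Before the fix these ids were created on self.device (e.g. cuda:0), and
    # comparisons against freshly generated tokens could fail with a
    # "Expected all tensors to be on the same device" error.
    assert eos_token_id.device == inputs_tensor.device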