Mirror of https://github.com/huggingface/transformers.git
Generation / FIX: Fix multi-device generation (#30746)
* attempt to fix multi-device generation
* fix
* final fix
* final fix
* fix
* fix
* fix
* fix
* add joao suggestion
* fix
This commit is contained in:
parent a0779b9e19
commit f823fec53e
@@ -476,6 +476,7 @@ class GenerationMixin:
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
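For reference, the context lines above infer an attention mask from padding only when it is unambiguous to do so (a pad token actually appears in the inputs and cannot be confused with an eos token). A self-contained sketch of that inference with made-up example tensors (not code from this commit) is:

import torch

# Hypothetical batch: two left-padded sequences, pad_token_id = 0, eos_token_id = 2.
inputs = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 2]])
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor([2])

default_attention_mask = torch.ones_like(inputs)

# Only infer the mask when a pad token is present and it differs from every eos token.
is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(inputs, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(eos_token_id, pad_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = inputs.ne(pad_token_id).long()
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # -> tensor([[0, 0, 1, 1], [1, 1, 1, 1]])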
@@ -1340,7 +1341,10 @@ class GenerationMixin:
         return self._static_cache

     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
@@ -1352,15 +1356,18 @@ class GenerationMixin:
         """

         # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token):
+        def _tensor_or_none(token, device=None):
+            if device is None:
+                device = self.device
+
             if token is None or isinstance(token, torch.Tensor):
                 return token
-            return torch.tensor(token, device=self.device, dtype=torch.long)
+            return torch.tensor(token, device=device, dtype=torch.long)

-        bos_token_id = _tensor_or_none(generation_config.bos_token_id)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
+        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
+        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
+        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
+        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id

         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
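The core of the fix is visible here: `_tensor_or_none` now honours an explicit `device` argument and only falls back to `self.device` when none is given. A rough standalone sketch of the same pattern, using a hypothetical free function `tensor_or_none` with a CPU fallback standing in for `self.device` (not the library's helper itself), looks like:

from typing import Optional, Union

import torch


def tensor_or_none(
    token: Optional[Union[int, list, torch.Tensor]],
    device: Optional[Union[torch.device, str]] = None,
) -> Optional[torch.Tensor]:
    # Fall back to CPU here; in the method above the fallback is self.device.
    if device is None:
        device = torch.device("cpu")
    if token is None or isinstance(token, torch.Tensor):
        return token  # already a tensor (or absent): leave it untouched
    return torch.tensor(token, device=device, dtype=torch.long)


# A single eos id, a list of eos ids, and a missing pad id all round-trip;
# the caller decides the target device (e.g. the device of the input ids).
print(tensor_or_none(2))           # -> tensor(2)
print(tensor_or_none([2, 32000]))  # -> tensor([2, 32000])
print(tensor_or_none(None))        # -> None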
@@ -1511,7 +1518,6 @@ class GenerationMixin:
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)

         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
@@ -1519,6 +1525,9 @@ class GenerationMixin:
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
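Net effect: `generate()` now prepares bos/eos/pad tensors on the device of the prepared input tensor instead of `self.device`, which is the case that breaks when a checkpoint is sharded across devices with `device_map="auto"`. A hedged usage sketch of the scenario this targets (illustrative model id; assumes `accelerate` is installed and at least one CUDA device is available; not part of the PR) could be:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-1.3b"  # assumption: any causal LM checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" shards the model across the available GPUs (and CPU if needed),
# so the device holding the inputs and the module's nominal device can differ.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Multi-device generation test:", return_tensors="pt").to(model.device)

# With the change above, generate() places the special-token tensors on the
# inputs' device, so comparisons against the generated ids stay on one device.
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))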