Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
Fix default attention mask of generate in MoshiForConditionalGeneration (#36171)
parent 27d1707586
commit e18f233f6c
@@ -2099,6 +2099,31 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )
 
+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
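
To see what the new helper computes, here is a minimal standalone sketch of the same mask-inference logic (the token IDs below are made-up values for illustration, not from this commit): when the batch contains pad tokens and the pad token differs from the eos token, the mask is inferred from padding; otherwise an all-ones mask is returned.

import torch

pad_token_id = torch.tensor(0)  # assumed pad id for the sketch
eos_token_id = torch.tensor(2)  # assumed eos id, distinct from pad
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [5, 7, 9, 9, 9]])  # right-padded batch

default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long)

is_pad_token_in_inputs = torch.isin(input_ids, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(eos_token_id, pad_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
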
@@ -2315,6 +2340,12 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         kwargs_depth_decoder = depth_decoder_generation_config
 
         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
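
A quick, weight-free way to confirm the dispatch this call site relies on (a sanity-check sketch, not part of the commit): after this change, MoshiForConditionalGeneration carries its own _prepare_attention_mask_for_generation rather than inheriting GenerationMixin's generic default.

from transformers import MoshiForConditionalGeneration
from transformers.generation import GenerationMixin

# The override added in this diff shadows the GenerationMixin default,
# so the two class attributes are different function objects.
moshi_impl = MoshiForConditionalGeneration._prepare_attention_mask_for_generation
base_impl = GenerationMixin._prepare_attention_mask_for_generation
print(moshi_impl is not base_impl)  # expected: True with this fix applied
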
@@ -2497,11 +2528,11 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         batch_size, sequence_length = input_ids.shape
         device = input_ids.device
 
-        attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
             target_length=past_key_values.get_max_cache_shape(),
-            dtype=self.lm_head.weight.dtype,
+            dtype=self.decoder.lm_head.weight.dtype,
             device=device,
             cache_position=cache_position,
             batch_size=batch_size,
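
For background on the corrected call above: _prepare_4d_causal_attention_mask_with_cache_position expands the 2D mask into a float mask of shape (batch_size, 1, sequence_length, target_length), where blocked positions hold the dtype minimum. Below is a minimal standalone illustration of that shape and semantics under a made-up cache state; it is not the library helper itself and skips the padding-mask merge.

import torch

batch_size, sequence_length, target_length = 1, 3, 5
dtype = torch.float32
min_val = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # assume 2 tokens already in the cache

# Start fully blocked, then open each query row to cached and current positions.
causal_mask = torch.full((sequence_length, target_length), min_val, dtype=dtype)
allowed = torch.arange(target_length) <= cache_position.unsqueeze(-1)
causal_mask = causal_mask.masked_fill(allowed, 0.0)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask.shape)  # torch.Size([1, 1, 3, 5])
print((causal_mask == 0).long()[0, 0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
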