Fix default attention mask of generate in MoshiForConditionalGeneration (#36171)

Cyan 2025-02-21 04:53:27 +09:00 committed by GitHub
parent 27d1707586
commit e18f233f6c

@@ -2099,6 +2099,31 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
)
def _prepare_attention_mask_for_generation(
self,
input_ids: torch.LongTensor,
generation_config: GenerationConfig,
kwargs: Dict[str, Any],
) -> torch.LongTensor:
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
if pad_token_id is None:
return default_attention_mask
is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
eos_token_id, pad_token_id
).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
return attention_mask
def _prepare_inputs_embeds_for_generation(
self,
input_ids: Optional[torch.LongTensor] = None,
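
For reference, a minimal sketch (not part of the diff) of the inference above, assuming illustrative token ids with pad_token_id=3 and eos_token_id=2: the pad token appears in the inputs and differs from the eos token, so the mask is taken from the padding positions; otherwise the helper falls back to the all-ones default.

import torch

input_ids = torch.tensor([[3, 3, 5, 7],
                          [5, 7, 9, 11]])  # left-padded batch, pad_token_id = 3
pad_token_id, eos_token_id = 3, 2          # illustrative values, not Moshi defaults

# pad token present in the inputs and distinct from eos -> mask inferred from padding
attention_mask = input_ids.ne(pad_token_id).long()
# tensor([[0, 0, 1, 1],
#         [1, 1, 1, 1]])
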
@@ -2315,6 +2340,12 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
kwargs_depth_decoder = depth_decoder_generation_config
attention_mask = kwargs.pop("attention_mask", None)
if attention_mask is None:
attention_mask = self._prepare_attention_mask_for_generation(
input_ids=input_ids,
generation_config=generation_config,
kwargs=kwargs,
)
(
inputs_embeds,
input_ids,
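
At the generate() level, a caller that omits attention_mask now gets the inferred mask instead of attending over pad positions. A hedged usage sketch, where model stands for a loaded MoshiForConditionalGeneration and the config values are the illustrative ones from above:

from transformers import GenerationConfig

gen_cfg = GenerationConfig(pad_token_id=3, eos_token_id=2)  # illustrative values
mask = model._prepare_attention_mask_for_generation(input_ids, gen_cfg, {})
# matches what generate() now builds internally when attention_mask is omitted
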
@@ -2497,11 +2528,11 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
batch_size, sequence_length = input_ids.shape
device = input_ids.device
- attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
- dtype=self.lm_head.weight.dtype,
+ dtype=self.decoder.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
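
The last hunk reroutes the static-cache mask construction through the text decoder: the replaced lines looked up model and lm_head on the conditional-generation wrapper, while those attributes belong to its decoder (a MoshiForCausalLM). A rough sketch of the assumed module layout:

# MoshiForConditionalGeneration        (assumed layout, for orientation only)
# ├── decoder (MoshiForCausalLM)
# │   ├── model (MoshiModel)   -> owns _prepare_4d_causal_attention_mask_with_cache_position
# │   └── lm_head              -> supplies the dtype for the causal mask
# ├── depth_decoder
# └── audio_encoder
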