diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index ae9a3fd804d..d1f4f8a9cab 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -2099,6 +2099,31 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )
 
+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -2315,6 +2340,12 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             kwargs_depth_decoder = depth_decoder_generation_config
 
         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
@@ -2497,11 +2528,11 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         batch_size, sequence_length = input_ids.shape
         device = input_ids.device
 
-        attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
             target_length=past_key_values.get_max_cache_shape(),
-            dtype=self.lm_head.weight.dtype,
+            dtype=self.decoder.lm_head.weight.dtype,
             device=device,
             cache_position=cache_position,
             batch_size=batch_size,
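
As a reviewer-side sanity check (not part of the patch), here is a minimal sketch of the rule the new _prepare_attention_mask_for_generation helper applies: pad positions become 0 and everything else 1, but only when a pad token actually appears in the inputs and is distinguishable from the eos token; otherwise the helper falls back to an all-ones mask. The token ids below are made up for illustration, and eos_token_id is wrapped in a tensor because torch.isin needs at least one tensor operand.

import torch

# Toy left-padded batch; token ids are illustrative, not Moshi's real ones.
input_ids = torch.tensor([[32000, 32000, 17, 42], [5, 9, 13, 27]])
pad_token_id = 32000
eos_token_id = torch.tensor(2)

# Same two conditions the helper checks: padding must actually occur in the
# inputs, and pad must differ from eos (otherwise a trailing eos token would
# be wrongly masked out).
is_pad_token_in_inputs = torch.isin(input_ids, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(eos_token_id, pad_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
attention_mask_from_padding = input_ids.ne(pad_token_id).long()

# Boolean arithmetic selects one of the two masks without branching.
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)
# tensor([[0, 0, 1, 1],
#         [1, 1, 1, 1]])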
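
The last hunk is worth calling out separately: MoshiForConditionalGeneration composes its text model rather than inheriting it, so _prepare_4d_causal_attention_mask_with_cache_position and lm_head live on the inner decoder, not on the wrapper, and the previous self.model / self.lm_head lookups would fail with an AttributeError once prepare_inputs_for_generation took the static-cache mask path. A rough sketch of the attribute layout the fix assumes (names as they appear in modeling_moshi.py):

# MoshiForConditionalGeneration
# └── decoder: MoshiForCausalLM
#     ├── model: MoshiModel   <- owns _prepare_4d_causal_attention_mask_with_cache_position
#     └── lm_head             <- supplies the dtype used to fill the 4D mask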