Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Fix default attention mask of generate in MoshiForConditionalGeneration (#36171)
commit e18f233f6c
parent 27d1707586
@@ -2099,6 +2099,31 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )
 
+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
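For reference, the inference rule added above can be exercised on its own. A minimal sketch, not part of the commit: the token ids below are made up, and tensor-valued pad/eos ids are used so that torch.isin accepts both arguments.

```python
import torch

# Hypothetical toy batch: two prompts, the first one left-padded with id 0.
input_ids = torch.tensor([[0, 0, 5, 6, 7], [1, 2, 3, 4, 5]])
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor(2)

# Fallback when nothing can be inferred: attend to every position.
default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long)

# Same rule as the helper above: only trust the padding pattern when a pad
# token actually appears in the inputs and it is distinct from the eos token.
is_pad_token_in_inputs = torch.isin(input_ids, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(eos_token_id, pad_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)
# tensor([[0, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1]])
```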
@@ -2315,6 +2340,12 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         kwargs_depth_decoder = depth_decoder_generation_config
 
         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
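The hunk above only changes what happens when the caller omits attention_mask; an explicitly supplied mask still wins. The precedence, paraphrased as a standalone function (the name resolve_attention_mask and the free-function form are mine, not part of the commit):

```python
from typing import Any, Dict, Optional

import torch


def resolve_attention_mask(
    model,  # a MoshiForConditionalGeneration instance
    input_ids: torch.LongTensor,
    generation_config,
    kwargs: Dict[str, Any],
) -> torch.LongTensor:
    # A mask passed by the caller is used unchanged; otherwise one is inferred
    # from input_ids and the generation config via the helper added above.
    attention_mask: Optional[torch.Tensor] = kwargs.pop("attention_mask", None)
    if attention_mask is None:
        attention_mask = model._prepare_attention_mask_for_generation(
            input_ids=input_ids,
            generation_config=generation_config,
            kwargs=kwargs,
        )
    return attention_mask
```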
@@ -2497,11 +2528,11 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             batch_size, sequence_length = input_ids.shape
             device = input_ids.device
 
-            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
                 target_length=past_key_values.get_max_cache_shape(),
-                dtype=self.lm_head.weight.dtype,
+                dtype=self.decoder.lm_head.weight.dtype,
                 device=device,
                 cache_position=cache_position,
                 batch_size=batch_size,
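This last hunk only re-roots two attribute lookups. As the diff implies, the conditional-generation wrapper does not expose model or lm_head directly; both live on the causal decoder under self.decoder. A toy sketch of that containment, using stand-in classes rather than the real transformers modules:

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    # Stand-in for the wrapped causal LM: it owns the backbone and the lm_head.
    def __init__(self):
        super().__init__()
        self.model = nn.Identity()
        self.lm_head = nn.Linear(8, 16, dtype=torch.bfloat16)


class ToyConditionalGeneration(nn.Module):
    # Stand-in for the wrapper: no lm_head of its own, only a .decoder.
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder()


m = ToyConditionalGeneration()
print(m.decoder.lm_head.weight.dtype)  # torch.bfloat16 -- the path the fix uses
# m.lm_head would raise AttributeError here, mirroring the old attribute path.
```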