From 132402d752044301b37e54405832738b16f49df6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 20 Jun 2022 11:07:46 +0100 Subject: [PATCH] TF: BART compatible with XLA generation (#17479) * Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus --- .../models/bart/modeling_tf_bart.py | 113 ++++++++++++++++-- .../blenderbot/modeling_tf_blenderbot.py | 46 +++++-- .../modeling_tf_blenderbot_small.py | 48 ++++++-- .../models/clip/modeling_tf_clip.py | 2 +- .../models/hubert/modeling_tf_hubert.py | 2 +- .../models/led/modeling_tf_led.py | 7 +- .../models/marian/modeling_tf_marian.py | 46 +++++-- .../models/mbart/modeling_tf_mbart.py | 49 ++++++-- .../models/opt/modeling_tf_opt.py | 7 +- .../models/pegasus/modeling_tf_pegasus.py | 46 +++++-- .../modeling_tf_speech_to_text.py | 5 +- src/transformers/models/t5/modeling_tf_t5.py | 15 +-- .../models/wav2vec2/modeling_tf_wav2vec2.py | 3 +- ...tf_{{cookiecutter.lowercase_modelname}}.py | 2 +- tests/models/bart/test_modeling_tf_bart.py | 110 ++++++++++++++--- tests/models/gpt2/test_modeling_tf_gpt2.py | 2 +- tests/models/t5/test_modeling_tf_t5.py | 2 +- tests/test_modeling_tf_common.py | 2 +- 18 files changed, 421 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 8f8586c7913..0a150b6ea87 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -20,6 +20,7 @@ from typing import Optional, Tuple, Union import numpy as np import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import ( @@ -87,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -99,7 +101,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -123,12 +125,19 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, + input_shape: Optional[tf.TensorShape] = None, + past_key_values_length: int = 0, + position_ids: Optional[tf.Tensor] = None, + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(seq_len, delta=1, name="range") + position_ids += past_key_values_length - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return super().call(positions + self.offset) + return super().call(position_ids + self.offset) class TFBartAttention(tf.keras.layers.Layer): @@ -599,6 +608,9 @@ BART_INPUTS_DOCSTRING = r""" for denoising pre-training following the paper. decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -838,6 +850,7 @@ class TFBartDecoder(tf.keras.layers.Layer): input_ids: Optional[TFModelInputType] = None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, @@ -866,6 +879,9 @@ class TFBartDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -922,7 +938,10 @@ class TFBartDecoder(tf.keras.layers.Layer): past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -1058,6 +1077,7 @@ class TFBartMainLayer(tf.keras.layers.Layer): attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, @@ -1112,6 +1132,7 @@ class TFBartMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1173,6 +1194,7 @@ class TFBartModel(TFBartPretrainedModel): attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, @@ -1193,6 +1215,7 @@ class TFBartModel(TFBartPretrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1278,6 +1301,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, @@ -1320,6 +1344,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1375,6 +1400,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1382,22 +1408,95 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length): + # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored + # quite some duplicated code patterns it seems + past = outputs.past_key_values + is_past_initialized = model_kwargs.pop("past", None) is not None + decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None) + batch_size = past[0][0].shape[0] + + if not is_past_initialized: + # past[0][0].shape[2] is seq_length of prompt + # The padded version of `past` requires only `max_length - 1` steps along the time dimension. + num_padding_values = max_length - past[0][0].shape[2] - 1 + # prepare the padding tensor for `tf.pad`. + # `shape=(4, 2)` because each tensor element in `past` has `rank=4`. + # `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward). + padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2)) + + new_past = () + for past_layer in past: + new_past_layer = list(past_layer) + for i in range(len(new_past_layer[:2])): + new_past_layer[i] = tf.pad(past_layer[i], padding_values) + new_past += (tuple(new_past_layer),) + + # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, + # ones for the actual input_ids + decoder_attention_mask = tf.concat( + [ + tf.ones((batch_size, 1), dtype=tf.int32), + tf.zeros((batch_size, num_padding_values), dtype=tf.int32), + tf.ones((batch_size, 1), dtype=tf.int32), + ], + axis=1, + ) + else: + slice_start_base = tf.constant([0, 0, 1, 0]) + decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype) + # correct 5 here + new_past_index = current_pos - 1 + + new_past = () + for past_layer in past: + new_past_layer = list(past_layer) + for i in range(len(new_past_layer[:2])): + update_slice = past_layer[i][:, :, -1:] + # Write the last slice to the first open location in the padded past array + # and then truncate the last slice off the array + new_past_layer[i] = dynamic_update_slice( + past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index + ) + new_past += (tuple(new_past_layer),) + + update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index + decoder_attention_mask = dynamic_update_slice( + decoder_attention_mask, decoder_attention_mask_update_slice, update_start + ) + + # set `decoder_attention_mask` and `past` + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + model_kwargs["past"] = new_past + + return model_kwargs + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 24ed4baa969..2bede02ab25 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -89,7 +89,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -102,7 +103,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -123,12 +124,14 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings): def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): super().__init__(num_embeddings, embedding_dim, **kwargs) - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] - - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return super().call(positions) + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return super().call(position_ids) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot @@ -582,6 +585,9 @@ BLENDERBOT_INPUTS_DOCSTRING = r""" `past_key_values`). decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -827,6 +833,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): input_ids=None, inputs_embeds=None, attention_mask=None, + position_ids=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, @@ -855,6 +862,9 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -916,7 +926,10 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -1049,6 +1062,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1092,6 +1106,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1166,6 +1181,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1185,6 +1201,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1285,6 +1302,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1326,6 +1344,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1383,6 +1402,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1390,16 +1410,26 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 157af644684..501b3e9df10 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -123,12 +124,14 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings): def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): super().__init__(num_embeddings, embedding_dim, **kwargs) - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] - - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return super().call(positions) + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return super().call(position_ids) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->BlenderbotSmall @@ -587,6 +590,9 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r""" `past_key_values`). decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -831,6 +837,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): input_ids=None, inputs_embeds=None, attention_mask=None, + position_ids=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, @@ -859,6 +866,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -938,7 +948,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) hidden_states = self.layernorm_embedding(inputs_embeds) + positions hidden_states = self.dropout(hidden_states, training=training) @@ -1050,6 +1063,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1094,6 +1108,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1152,6 +1167,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1172,6 +1188,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1256,6 +1273,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1296,11 +1314,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1355,6 +1374,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1362,16 +1382,26 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py index 6ba83f04b84..8635c7d7602 100644 --- a/src/transformers/models/clip/modeling_tf_clip.py +++ b/src/transformers/models/clip/modeling_tf_clip.py @@ -59,7 +59,7 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index d659d2cacb5..bc442ad4ae6 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -263,7 +263,7 @@ def _compute_mask_indices( # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 83a71a0dfe8..94f1c7cbc48 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -74,11 +74,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to return shifted_input_ids +# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -90,7 +92,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index d5f41abe133..d356b4f8424 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -162,12 +163,14 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer): tf.stop_gradient(table) return table - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] - - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return tf.gather(self.weight, positions) + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, position_ids) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian @@ -628,6 +631,9 @@ MARIAN_INPUTS_DOCSTRING = r""" `past_key_values`). decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -870,6 +876,7 @@ class TFMarianDecoder(tf.keras.layers.Layer): input_ids=None, inputs_embeds=None, attention_mask=None, + position_ids=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, @@ -898,6 +905,9 @@ class TFMarianDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -960,7 +970,10 @@ class TFMarianDecoder(tf.keras.layers.Layer): past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -1091,6 +1104,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1138,6 +1152,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1196,6 +1211,7 @@ class TFMarianModel(TFMarianPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1215,6 +1231,7 @@ class TFMarianModel(TFMarianPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1299,6 +1316,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1341,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1398,6 +1417,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1405,16 +1425,26 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index fa19d711a31..b33de11113a 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -86,7 +86,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -99,7 +100,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -124,12 +125,19 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs) - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, + input_shape: Optional[tf.TensorShape] = None, + past_key_values_length: int = 0, + position_ids: Optional[tf.Tensor] = None, + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(seq_len, delta=1, name="range") + position_ids += past_key_values_length - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return super().call(positions + self.offset) + return super().call(position_ids + self.offset) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart @@ -569,6 +577,9 @@ MBART_INPUTS_DOCSTRING = r""" for denoising pre-training following the paper. decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -853,6 +864,7 @@ class TFMBartDecoder(tf.keras.layers.Layer): input_ids: TFModelInputType = None, inputs_embeds: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, @@ -883,6 +895,9 @@ class TFMBartDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -945,7 +960,10 @@ class TFMBartDecoder(tf.keras.layers.Layer): past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -1079,6 +1097,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1129,6 +1148,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1187,6 +1207,7 @@ class TFMBartModel(TFMBartPreTrainedModel): attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1207,6 +1228,7 @@ class TFMBartModel(TFMBartPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1291,6 +1313,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo attention_mask: Optional[tf.Tensor] = None, decoder_input_ids: Optional[tf.Tensor] = None, decoder_attention_mask: Optional[tf.Tensor] = None, + decoder_position_ids: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None, decoder_head_mask: Optional[tf.Tensor] = None, cross_attn_head_mask: Optional[tf.Tensor] = None, @@ -1331,6 +1354,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1388,6 +1412,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1395,16 +1420,26 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py index 0c3de0ce206..4353020485a 100644 --- a/src/transformers/models/opt/modeling_tf_opt.py +++ b/src/transformers/models/opt/modeling_tf_opt.py @@ -56,11 +56,13 @@ _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] LARGE_NEGATIVE = -1e8 +# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -72,7 +74,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 2c5696f94d3..578369e774e 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -163,12 +164,14 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer): tf.stop_gradient(table) return table - def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None + ): """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = input_shape[:2] - - positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") - return tf.gather(self.weight, positions) + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, position_ids) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus @@ -627,6 +630,9 @@ PEGASUS_INPUTS_DOCSTRING = r""" `past_key_values`). decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: @@ -876,6 +882,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer): input_ids=None, inputs_embeds=None, attention_mask=None, + position_ids=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, @@ -904,6 +911,9 @@ class TFPegasusDecoder(tf.keras.layers.Layer): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -966,7 +976,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer): past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -1099,6 +1112,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1146,6 +1160,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer): decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, @@ -1204,6 +1219,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1224,6 +1240,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1308,6 +1325,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_position_ids=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1350,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, @@ -1407,6 +1426,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1414,16 +1434,26 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua encoder_outputs=None, **kwargs ): + # cut decoder_input_ids if past is used if past is not None: decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_attention_mask is not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past is not None: # no xla + past + decoder_position_ids = past[0][0].shape[2] + else: # no xla + no past + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index f61ddd7fed0..b8be2b6f95e 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -90,7 +90,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) @@ -103,7 +104,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i # Copied from transformers.models.bart.modeling_tf_bart._expand_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 77a65557daa..5163e33f34e 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1504,21 +1504,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length): # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored # quite some duplicated code patterns it seems - # also the `attention_mask` is currently used in a somewhat hacky to - # correctly influence the `past_key_values` - not sure if this is the way to go - # Let's keep that for a future PR. past = outputs.past_key_values is_past_initialized = model_kwargs.pop("past", None) is not None decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None) batch_size = past[0][0].shape[0] if not is_past_initialized: - # past[0].shape[3] is seq_length of prompt + # past[0].shape[2] is seq_length of prompt num_padding_values = max_length - past[0][0].shape[2] - 1 - - padding_values = np.zeros((4, 2), dtype=np.int32) - padding_values[2, 1] = num_padding_values - padding_values = tf.constant(padding_values) + padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2)) new_past = () for past_layer in past: @@ -1527,7 +1521,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling new_past_layer[i] = tf.pad(past_layer[i], padding_values) new_past += (tuple(new_past_layer),) - # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids + # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, + # ones for the actual input_ids decoder_attention_mask = tf.concat( [ tf.ones((batch_size, 1), dtype=tf.int32), @@ -1559,7 +1554,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling decoder_attention_mask, decoder_attention_mask_update_slice, update_start ) - # set `attention_mask` and `past` + # set `decoder_attention_mask` and `past` model_kwargs["decoder_attention_mask"] = decoder_attention_mask model_kwargs["past"] = new_past diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index 567f20040b9..bf229faade9 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -297,7 +297,8 @@ def _compute_mask_indices( return spec_aug_mask -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index f5c40b27d61..c5224bfccb3 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -1716,7 +1716,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) -def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index 1e599c6b1ba..0df55500db3 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -125,8 +125,22 @@ class TFBartModelTester: next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1) - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0] + decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32) + output_from_no_past = model( + next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids + ) + output_from_no_past = output_from_no_past[0] + + decoder_position_ids = ( + tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2] + ) + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + position_ids=decoder_position_ids, + ) + output_from_past = output_from_past[0] self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) @@ -138,6 +152,23 @@ class TFBartModelTester: # test that outputs are equal for slice tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args): + config.eos_token_id = None # Generate until max length + config.max_length = 10 + config.do_sample = False + config.num_beams = 1 + model = TFBartForConditionalGeneration(config=config) + + # make sure there are no pad tokens in prompt + input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1) + + generated = model.generate(input_ids) + + generate_xla = tf.function(model.generate, jit_compile=True) + generated_xla = generate_xla(input_ids) + + self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist()) + def prepare_bart_inputs_dict( config, @@ -279,25 +310,15 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC models_equal = False self.assertTrue(models_equal) + def test_bart_model_xla_generate_fast(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"]) + def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass -def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): - """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" - if a is None and b is None: - return True - try: - if tf.debugging.assert_near(a, b, atol=atol): - return True - raise - except Exception: - if len(prefix) > 0: - prefix = f"{prefix}: " - raise AssertionError(f"{prefix}{a} != {b}") - - def _long_tensor(tok_lst): return tf.constant(tok_lst, dtype=tf.int32) @@ -682,6 +703,63 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase): result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0] assert result == EXPECTED + def test_xsum_1_1_xla_greedy_generation(self): + # TODO (Joao): this is temporary test, while XLA beam search is not operational. Move the XLA==non-XLA + # comparisons to the other tests after enabling XLA beam search. + # Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA + model = self.xsum_1_1_model + assert model.model.decoder.embed_tokens._layer == model.model.shared + ARTICLE = ( + "The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes." + ) + EXPECTED = ( + " The International Criminal Court (ICC) has announced that it is to be investigated by the International" + " Criminal Court (ICC) over claims that the Palestinian genocide." + ) + dct = self.tok(ARTICLE, return_tensors="tf") + generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0) + result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0] + assert result == EXPECTED + + xla_generate = tf.function(model.generate, jit_compile=True) + generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0) + result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0] + assert result == EXPECTED + def test_xsum_1_1_batch_generation(self): batch = self.tok( [ diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index 93b48ce8f29..efa3f0ac1c0 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -295,7 +295,7 @@ class TFGPT2ModelTester: self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args): - config.eos_token_id = None + config.eos_token_id = None # Generate until max length config.max_length = 10 model = TFGPT2LMHeadModel(config=config) diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index 5ad746e34fc..e815fd7ad07 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -228,7 +228,7 @@ class TFT5ModelTester: tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args): - config.eos_token_id = None + config.eos_token_id = None # Generate until max length config.max_length = 10 config.do_sample = False config.num_beams = 1 diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 908d0722207..3a8a9c80cfb 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -214,10 +214,10 @@ class TFModelTesterMixin: "decoder_input_ids", "decoder_attention_mask", ] + expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else []) expected_arg_names.extend( ["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else [] ) - # Necessary to handle BART with newly added cross_attn_head_mask expected_arg_names.extend( ["cross_attn_head_mask", "encoder_outputs"] if "cross_attn_head_mask" in arg_names