TF: BART compatible with XLA generation (#17479)

* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
Joao Gante 2022-06-20 11:07:46 +01:00 committed by GitHub
parent 6589e510fa
commit 132402d752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 421 additions and 86 deletions

View File

@ -20,6 +20,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
@ -87,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -99,7 +101,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -123,12 +125,19 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self,
input_shape: Optional[tf.TensorShape] = None,
past_key_values_length: int = 0,
position_ids: Optional[tf.Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions + self.offset)
return super().call(position_ids + self.offset)
class TFBartAttention(tf.keras.layers.Layer):
@ -599,6 +608,9 @@ BART_INPUTS_DOCSTRING = r"""
for denoising pre-training following the paper.
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -838,6 +850,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
input_ids: Optional[TFModelInputType] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
@ -866,6 +879,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -922,7 +938,10 @@ class TFBartDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@ -1058,6 +1077,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
@ -1112,6 +1132,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1173,6 +1194,7 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
@ -1193,6 +1215,7 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1278,6 +1301,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
@ -1320,6 +1344,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1375,6 +1400,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1382,22 +1408,95 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]
if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
# The padded version of `past` requires only `max_length - 1` steps along the time dimension.
num_padding_values = max_length - past[0][0].shape[2] - 1
# prepare the padding tensor for `tf.pad`.
# `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
# `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)
# 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor,
# ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# correct 5 here
new_past_index = current_pos - 1
new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)
update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)
# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

View File

@ -89,7 +89,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -102,7 +103,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -123,12 +124,14 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions)
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot
@ -582,6 +585,9 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -827,6 +833,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@ -855,6 +862,9 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -916,7 +926,10 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@ -1049,6 +1062,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1092,6 +1106,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1166,6 +1181,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1185,6 +1201,7 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1285,6 +1302,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1326,6 +1344,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1383,6 +1402,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1390,16 +1410,26 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,

View File

@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -123,12 +124,14 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions)
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->BlenderbotSmall
@ -587,6 +590,9 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -831,6 +837,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@ -859,6 +866,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -938,7 +948,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
hidden_states = self.layernorm_embedding(inputs_embeds) + positions
hidden_states = self.dropout(hidden_states, training=training)
@ -1050,6 +1063,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1094,6 +1108,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1152,6 +1167,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1172,6 +1188,7 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1256,6 +1273,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1296,11 +1314,12 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
@ -1355,6 +1374,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1362,16 +1382,26 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,

View File

@ -59,7 +59,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -263,7 +263,7 @@ def _compute_mask_indices(
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -74,11 +74,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
return shifted_input_ids
# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -90,7 +92,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -162,12 +163,14 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
tf.stop_gradient(table)
return table
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return tf.gather(self.weight, positions)
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return tf.gather(self.weight, position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian
@ -628,6 +631,9 @@ MARIAN_INPUTS_DOCSTRING = r"""
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -870,6 +876,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@ -898,6 +905,9 @@ class TFMarianDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -960,7 +970,10 @@ class TFMarianDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@ -1091,6 +1104,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1138,6 +1152,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1196,6 +1211,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1215,6 +1231,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1299,6 +1316,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1341,6 +1359,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1398,6 +1417,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1405,16 +1425,26 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,

View File

@ -86,7 +86,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -99,7 +100,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -124,12 +125,19 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self,
input_shape: Optional[tf.TensorShape] = None,
past_key_values_length: int = 0,
position_ids: Optional[tf.Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions + self.offset)
return super().call(position_ids + self.offset)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
@ -569,6 +577,9 @@ MBART_INPUTS_DOCSTRING = r"""
for denoising pre-training following the paper.
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -853,6 +864,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
input_ids: TFModelInputType = None,
inputs_embeds: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
@ -883,6 +895,9 @@ class TFMBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -945,7 +960,10 @@ class TFMBartDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@ -1079,6 +1097,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1129,6 +1148,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1187,6 +1207,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1207,6 +1228,7 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1291,6 +1313,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
attention_mask: Optional[tf.Tensor] = None,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
decoder_position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
decoder_head_mask: Optional[tf.Tensor] = None,
cross_attn_head_mask: Optional[tf.Tensor] = None,
@ -1331,6 +1354,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1388,6 +1412,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1395,16 +1420,26 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,

View File

@ -56,11 +56,13 @@ _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -72,7 +74,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -88,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -101,7 +102,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@ -163,12 +164,14 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
tf.stop_gradient(table)
return table
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: Optional[tf.Tensor] = None
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return tf.gather(self.weight, positions)
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return tf.gather(self.weight, position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
@ -627,6 +630,9 @@ PEGASUS_INPUTS_DOCSTRING = r"""
`past_key_values`).
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@ -876,6 +882,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
position_ids=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
@ -904,6 +911,9 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@ -966,7 +976,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@ -1099,6 +1112,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1146,6 +1160,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
@ -1204,6 +1219,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1224,6 +1240,7 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1308,6 +1325,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1350,6 +1368,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
@ -1407,6 +1426,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
@ -1414,16 +1434,26 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
encoder_outputs=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,

View File

@ -90,7 +90,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
@ -103,7 +104,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -1504,21 +1504,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
# also the `attention_mask` is currently used in a somewhat hacky to
# correctly influence the `past_key_values` - not sure if this is the way to go
# Let's keep that for a future PR.
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]
if not is_past_initialized:
# past[0].shape[3] is seq_length of prompt
# past[0].shape[2] is seq_length of prompt
num_padding_values = max_length - past[0][0].shape[2] - 1
padding_values = np.zeros((4, 2), dtype=np.int32)
padding_values[2, 1] = num_padding_values
padding_values = tf.constant(padding_values)
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
new_past = ()
for past_layer in past:
@ -1527,7 +1521,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)
# 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
# 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor,
# ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
@ -1559,7 +1554,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)
# set `attention_mask` and `past`
# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

View File

@ -297,7 +297,8 @@ def _compute_mask_indices(
return spec_aug_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -1716,7 +1716,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""

View File

@ -125,8 +125,22 @@ class TFBartModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
output_from_no_past = model(
next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
)
output_from_no_past = output_from_no_past[0]
decoder_position_ids = (
tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
)
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
output_from_past = output_from_past[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@ -138,6 +152,23 @@ class TFBartModelTester:
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFBartForConditionalGeneration(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def prepare_bart_inputs_dict(
config,
@ -279,25 +310,15 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
models_equal = False
self.assertTrue(models_equal)
def test_bart_model_xla_generate_fast(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"])
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if tf.debugging.assert_near(a, b, atol=atol):
return True
raise
except Exception:
if len(prefix) > 0:
prefix = f"{prefix}: "
raise AssertionError(f"{prefix}{a} != {b}")
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)
@ -682,6 +703,63 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
def test_xsum_1_1_xla_greedy_generation(self):
# TODO (Joao): this is temporary test, while XLA beam search is not operational. Move the XLA==non-XLA
# comparisons to the other tests after enabling XLA beam search.
# Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
ARTICLE = (
"The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
" formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
" The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
" situation in Palestinian territories, paving the way for possible war crimes investigations against"
" Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
" the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
" body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
" Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
" acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
" immediately end their pressure, and countries that support universal acceptance of the court's treaty"
' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
" group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
' decision to join a treaty to which over 100 countries around the world are members." In January, when'
" the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
" disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
" it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
" court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
" between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
" will include alleged war crimes committed since June. The International Criminal Court was set up in"
" 2002 to prosecute genocide, crimes against humanity and war crimes."
)
EXPECTED = (
" The International Criminal Court (ICC) has announced that it is to be investigated by the International"
" Criminal Court (ICC) over claims that the Palestinian genocide."
)
dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
def test_xsum_1_1_batch_generation(self):
batch = self.tok(
[

View File

@ -295,7 +295,7 @@ class TFGPT2ModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.eos_token_id = None # Generate until max length
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)

View File

@ -228,7 +228,7 @@ class TFT5ModelTester:
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1

View File

@ -214,10 +214,10 @@ class TFModelTesterMixin:
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else [])
expected_arg_names.extend(
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
)
# Necessary to handle BART with newly added cross_attn_head_mask
expected_arg_names.extend(
["cross_attn_head_mask", "encoder_outputs"]
if "cross_attn_head_mask" in arg_names