mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)
This commit is contained in:
parent
693ba2cc79
commit
31be02f14b
@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -229,30 +228,24 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -225,30 +224,24 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -225,30 +224,24 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
||||
|
||||
# sanity check
|
||||
# assert shape_list(mask) == [bs, slen]
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
if causal:
|
||||
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||
|
||||
return mask, attn_mask
|
||||
|
||||
@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
# assert lengths.max().item() <= slen
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||
position_ids = tf.tile(position_ids, (bs, 1))
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if langs is not None and tf.executing_eagerly():
|
||||
if langs is not None:
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(langs), [bs, slen]
|
||||
|
@ -816,30 +816,24 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
|
@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
if tf.executing_eagerly():
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
value_vectors = self.value(hidden_states)
|
||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||
)
|
||||
|
||||
# normalize query
|
||||
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
||||
@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
# pad local attention probs
|
||||
attn_scores += diagonal_mask
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_scores),
|
||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||
message=(
|
||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_scores),
|
||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||
message=(
|
||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||
),
|
||||
)
|
||||
|
||||
# compute global attn indices required through out forward fn
|
||||
(
|
||||
@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
||||
|
||||
@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
),
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
||||
message="Unexpected size",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||
)
|
||||
|
||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||
|
||||
@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
"""
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(query),
|
||||
shape_list(key),
|
||||
message=(
|
||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||
f" {shape_list(key)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(query),
|
||||
shape_list(key),
|
||||
message=(
|
||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||
f" {shape_list(key)}"
|
||||
),
|
||||
)
|
||||
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
|
||||
@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[:3],
|
||||
shape_list(value)[:3],
|
||||
message="value and attn_probs must have same dims (except head_dim)",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[3],
|
||||
2 * window_overlap + 1,
|
||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[:3],
|
||||
shape_list(value)[:3],
|
||||
message="value and attn_probs must have same dims (except head_dim)",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[3],
|
||||
2 * window_overlap + 1,
|
||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||
)
|
||||
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
|
||||
@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_value),
|
||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||
message="Chunked value has the wrong shape",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_value),
|
||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||
message="Chunked value has the wrong shape",
|
||||
)
|
||||
|
||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||
@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
# chunk with overlap
|
||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_hidden_states),
|
||||
[batch_size, num_output_chunks, frame_size],
|
||||
message=(
|
||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_hidden_states),
|
||||
[batch_size, num_output_chunks, frame_size],
|
||||
message=(
|
||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||
),
|
||||
)
|
||||
|
||||
chunked_hidden_states = tf.reshape(
|
||||
chunked_hidden_states,
|
||||
@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
# compute attn scores
|
||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_scores),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||
message=(
|
||||
"global_attn_scores have the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||
f" {shape_list(global_attn_scores)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_scores),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||
message=(
|
||||
"global_attn_scores have the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||
f" {shape_list(global_attn_scores)}."
|
||||
),
|
||||
)
|
||||
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores,
|
||||
@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# apply layer head masking
|
||||
if layer_head_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||
)
|
||||
@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
# global attn output
|
||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_output),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||
message=(
|
||||
"global_attn_output tensor has the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||
f" {shape_list(global_attn_output)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_output),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||
message=(
|
||||
"global_attn_output tensor has the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||
f" {shape_list(global_attn_output)}."
|
||||
),
|
||||
)
|
||||
|
||||
global_attn_output = tf.reshape(
|
||||
global_attn_output,
|
||||
@ -1091,26 +1075,24 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
|
||||
attention_mask, dtype=attn_weights.dtype
|
||||
@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = all_global_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = ()
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
value_vectors = self.value(hidden_states)
|
||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||
)
|
||||
|
||||
# normalize query
|
||||
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
||||
@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# pad local attention probs
|
||||
attn_scores += diagonal_mask
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_scores),
|
||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||
message=(
|
||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_scores),
|
||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||
message=(
|
||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||
),
|
||||
)
|
||||
|
||||
# compute global attn indices required through out forward fn
|
||||
(
|
||||
@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
||||
|
||||
@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
),
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
||||
message="Unexpected size",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||
)
|
||||
|
||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||
|
||||
@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
"""
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(query),
|
||||
shape_list(key),
|
||||
message=(
|
||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||
f" {shape_list(key)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(query),
|
||||
shape_list(key),
|
||||
message=(
|
||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||
f" {shape_list(key)}"
|
||||
),
|
||||
)
|
||||
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
|
||||
@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[:3],
|
||||
shape_list(value)[:3],
|
||||
message="value and attn_probs must have same dims (except head_dim)",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[3],
|
||||
2 * window_overlap + 1,
|
||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[:3],
|
||||
shape_list(value)[:3],
|
||||
message="value and attn_probs must have same dims (except head_dim)",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[3],
|
||||
2 * window_overlap + 1,
|
||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||
)
|
||||
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
|
||||
@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_value),
|
||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||
message="Chunked value has the wrong shape",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_value),
|
||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||
message="Chunked value has the wrong shape",
|
||||
)
|
||||
|
||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||
@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# chunk with overlap
|
||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_hidden_states),
|
||||
[batch_size, num_output_chunks, frame_size],
|
||||
message=(
|
||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(chunked_hidden_states),
|
||||
[batch_size, num_output_chunks, frame_size],
|
||||
message=(
|
||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||
),
|
||||
)
|
||||
|
||||
chunked_hidden_states = tf.reshape(
|
||||
chunked_hidden_states,
|
||||
@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# compute attn scores
|
||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_scores),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||
message=(
|
||||
"global_attn_scores have the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||
f" {shape_list(global_attn_scores)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_scores),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||
message=(
|
||||
"global_attn_scores have the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||
f" {shape_list(global_attn_scores)}."
|
||||
),
|
||||
)
|
||||
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores,
|
||||
@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# apply layer head masking
|
||||
if layer_head_mask is not None:
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||
)
|
||||
@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# global attn output
|
||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_output),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||
message=(
|
||||
"global_attn_output tensor has the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||
f" {shape_list(global_attn_output)}."
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_output),
|
||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||
message=(
|
||||
"global_attn_output tensor has the wrong size. Size should be"
|
||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||
f" {shape_list(global_attn_output)}."
|
||||
),
|
||||
)
|
||||
|
||||
global_attn_output = tf.reshape(
|
||||
global_attn_output,
|
||||
|
@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -264,30 +263,24 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -232,30 +232,24 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -206,30 +206,24 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -265,30 +264,24 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -324,30 +323,24 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
|
||||
training=training,
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
@ -852,30 +852,24 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
|
@ -239,30 +239,24 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attn_weights)}"
|
||||
),
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {shape_list(attention_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||
f" {shape_list(layer_head_mask)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
||||
attn_probs = self.dropout(attn_weights, training=training)
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {shape_list(attn_output)}"
|
||||
),
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
||||
|
||||
# sanity check
|
||||
# assert shape_list(mask) == [bs, slen]
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
if causal:
|
||||
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||
|
||||
return mask, attn_mask
|
||||
|
||||
@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
# assert lengths.max().item() <= slen
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||
position_ids = tf.tile(position_ids, (bs, 1))
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if langs is not None and tf.executing_eagerly():
|
||||
if langs is not None:
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(langs), [bs, slen]
|
||||
|
@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(layer_head_mask),
|
||||
[self.num_heads],
|
||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||
@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if head_mask is not None and tf.executing_eagerly():
|
||||
if head_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(head_mask)[0],
|
||||
len(self.layers),
|
||||
@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
present_key_values = () if use_cache else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||
if attn_mask is not None and tf.executing_eagerly():
|
||||
if attn_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_mask)[0],
|
||||
len(self.layers),
|
||||
|
Loading…
Reference in New Issue
Block a user