TF: tf.debugging assertions without tf.executing_eagerly() protection (#19030)

Joao Gante 2022-09-14 18:19:15 +01:00 committed by GitHub
parent 693ba2cc79
commit 31be02f14b
18 changed files with 720 additions and 971 deletions
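
Every file below follows the same pattern: the "if tf.executing_eagerly():" guard (and, where present, the comment stating that the tf.debugging asserts have to be disabled outside eager mode) is removed, and the assertion underneath it is dedented so the check is built in every execution mode. A minimal before/after sketch of that pattern, using illustrative names and sizes rather than code from any particular file in the diff:

    import tensorflow as tf

    EXPECTED = [8, 4, 4]  # illustrative (bsz * num_heads, tgt_len, src_len)

    def check_before(attn_weights):
        # Old pattern: the shape check was only built in eager mode.
        if tf.executing_eagerly():
            tf.debugging.assert_equal(
                attn_weights.shape.as_list(),
                EXPECTED,
                message=f"Attention weights should be of size {EXPECTED}, but is {attn_weights.shape.as_list()}",
            )

    def check_after(attn_weights):
        # New pattern: the same assertion, without the eager-only guard.
        tf.debugging.assert_equal(
            attn_weights.shape.as_list(),
            EXPECTED,
            message=f"Attention weights should be of size {EXPECTED}, but is {attn_weights.shape.as_list()}",
        )

    weights = tf.zeros((8, 4, 4))
    check_before(weights)
    check_after(weights)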

View File

@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
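
The control-dependency idiom kept in the hunk above exists because, once this code is traced into a graph, an assertion op whose output is never consumed could be pruned away; threading the result through tf.control_dependencies and a tf.identity no-op keeps the check alive. A stand-alone sketch of the same idiom, with toy inputs and an assumed function name:

    import tensorflow as tf

    @tf.function
    def clamp_labels(labels, pad_token_id=0):
        # Replace the -100 loss-masking value with the pad token, as shift_tokens_right does.
        clamped = tf.where(labels == -100, tf.fill(tf.shape(labels), pad_token_id), labels)
        # Build the assertion and attach it as a control dependency so it is not
        # pruned out of the traced graph.
        assert_gte0 = tf.debugging.assert_greater_equal(clamped, tf.constant(0, dtype=labels.dtype))
        with tf.control_dependencies([assert_gte0]):
            clamped = tf.identity(clamped)
        return clamped

    print(clamp_labels(tf.constant([[1, 2, -100]])))  # [[1 2 0]]
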
@ -229,30 +228,24 @@ class TFBartAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

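The re-enabled checks above also trace cleanly inside tf.function: shape_list returns plain Python integers whenever a dimension is statically known, so tf.debugging.assert_equal can typically resolve the comparison at trace time, and otherwise it simply adds an assertion op to the graph. A small sketch of the static case, with toy sizes and an assumed function name:

    import tensorflow as tf

    NUM_HEADS, TGT_LEN, SRC_LEN = 4, 3, 3  # illustrative sizes

    @tf.function
    def attention_scores(query, key):
        weights = tf.matmul(query, key, transpose_b=True)
        # With static shapes both sides of this check are Python lists, so the
        # comparison happens while the function is being traced.
        tf.debugging.assert_equal(
            weights.shape.as_list(),
            [NUM_HEADS, TGT_LEN, SRC_LEN],
            message="unexpected attention weight shape",
        )
        return tf.nn.softmax(weights, axis=-1)

    q = tf.random.normal((NUM_HEADS, TGT_LEN, 8))
    k = tf.random.normal((NUM_HEADS, SRC_LEN, 8))
    print(attention_scores(q, k).shape)  # (4, 3, 3)
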
View File

@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -225,30 +224,24 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -225,30 +224,24 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids

View File

@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
# sanity check
# assert shape_list(mask) == [bs, slen]
if tf.executing_eagerly():
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
if causal:
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
return mask, attn_mask
@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# check inputs
# assert shape_list(lengths)[0] == bs
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0)
position_ids = tf.tile(position_ids, (bs, 1))
if tf.executing_eagerly():
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None and tf.executing_eagerly():
if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]

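One detail worth noting in this file: a few of the checks keep the form tf.debugging.assert_equal(a, b), f"..." where the trailing f-string is the message of an earlier plain Python assert. Written that way it only builds an unused tuple and never reaches the op; the sketch below (illustrative tensors, not the model's real inputs) shows the same check with the text passed through the message keyword instead.

    import tensorflow as tf

    bs = 2                         # illustrative batch size
    lengths = tf.constant([5, 3])  # per-example lengths, shape (bs,)

    # Pass the explanation through message= so it appears in the raised error;
    # a trailing ", f-string" after the call is silently discarded.
    tf.debugging.assert_equal(
        tf.size(lengths),
        bs,
        message=f"Expected batch size {bs} and received batch size {int(tf.size(lengths))} mismatched",
    )
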
View File

@ -816,30 +816,24 @@ class TFHubertAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)

View File

@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
)
# "Verify that `labels` has only positive values and -100"
if tf.executing_eagerly():
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
value_vectors = self.value(hidden_states)
batch_size, seq_len, embed_dim = shape_list(hidden_states)
if tf.executing_eagerly():
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
# normalize query
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
attn_scores += diagonal_mask
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=(
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=(
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
),
)
# compute global attn indices required through out forward fn
(
@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
)
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
tf.debugging.assert_equal(
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=(
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
f" {shape_list(key)}"
),
)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=(
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
f" {shape_list(key)}"
),
)
chunks_count = seq_len // window_overlap - 1
@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
batch_size, seq_len, num_heads, head_dim = shape_list(value)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
chunks_count = seq_len // window_overlap - 1
@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
),
)
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
),
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# compute attn scores
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=(
"global_attn_scores have the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
f" {shape_list(global_attn_scores)}."
),
)
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=(
"global_attn_scores have the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
f" {shape_list(global_attn_scores)}."
),
)
global_attn_scores = tf.reshape(
global_attn_scores,
@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
)
@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=(
"global_attn_output tensor has the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
f" {shape_list(global_attn_output)}."
),
)
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=(
"global_attn_output tensor has the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
f" {shape_list(global_attn_output)}."
),
)
global_attn_output = tf.reshape(
global_attn_output,
@ -1091,26 +1075,24 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
attention_mask, dtype=attn_weights.dtype
@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
attn_output = tf.matmul(attn_probs, value_states)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
hidden_states = layer_outputs[0]
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
all_attentions = all_global_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),

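For the chunking check above that follows the tf.signal.frame call (mirrored in the Longformer file further down), the assertion verifies that framing produced 2 * (seq_len // (2 * window_overlap)) - 1 overlapping chunks, each of flattened length 2 * window_overlap * hidden_dim. A worked toy example with assumed sizes:

    import tensorflow as tf

    batch_size, seq_len, hidden_dim = 1, 8, 2  # illustrative sizes
    window_overlap = 2
    frame_hop_size = window_overlap * hidden_dim  # stride between chunk starts
    frame_size = 2 * frame_hop_size               # flattened length of each chunk
    num_output_chunks = 2 * (seq_len // (2 * window_overlap)) - 1

    hidden_states = tf.reshape(
        tf.range(batch_size * seq_len * hidden_dim, dtype=tf.float32),
        (batch_size, seq_len * hidden_dim),
    )
    chunked = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
    # This is exactly the shape the re-enabled assertion verifies.
    tf.debugging.assert_equal(chunked.shape.as_list(), [batch_size, num_output_chunks, frame_size])
    print(chunked.shape)  # (1, 3, 8)
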
View File

@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
value_vectors = self.value(hidden_states)
batch_size, seq_len, embed_dim = shape_list(hidden_states)
if tf.executing_eagerly():
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
tf.debugging.assert_equal(
embed_dim,
self.embed_dim,
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
)
# normalize query
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# pad local attention probs
attn_scores += diagonal_mask
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=(
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_scores),
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
message=(
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
),
)
# compute global attn indices required through out forward fn
(
@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[batch_size, seq_len, self.num_heads, self.head_dim],
message="Unexpected size",
)
tf.debugging.assert_equal(
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
)
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=(
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
f" {shape_list(key)}"
),
)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=(
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
f" {shape_list(key)}"
),
)
chunks_count = seq_len // window_overlap - 1
@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
batch_size, seq_len, num_heads, head_dim = shape_list(value)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
)
tf.debugging.assert_equal(
shape_list(attn_probs)[:3],
shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
chunks_count = seq_len // window_overlap - 1
@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
tf.debugging.assert_equal(
shape_list(chunked_value),
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
message="Chunked value has the wrong shape",
)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
),
)
tf.debugging.assert_equal(
shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=(
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
),
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# compute attn scores
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=(
"global_attn_scores have the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
f" {shape_list(global_attn_scores)}."
),
)
tf.debugging.assert_equal(
shape_list(global_attn_scores),
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
message=(
"global_attn_scores have the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
f" {shape_list(global_attn_scores)}."
),
)
global_attn_scores = tf.reshape(
global_attn_scores,
@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# apply layer head masking
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
)
@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=(
"global_attn_output tensor has the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
f" {shape_list(global_attn_output)}."
),
)
tf.debugging.assert_equal(
shape_list(global_attn_output),
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
message=(
"global_attn_output tensor has the wrong size. Size should be"
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
f" {shape_list(global_attn_output)}."
),
)
global_attn_output = tf.reshape(
global_attn_output,

View File

@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -264,30 +263,24 @@ class TFMarianAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -232,30 +232,24 @@ class TFMBartAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
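
The hunks above all make the same change: the `if tf.executing_eagerly():` guard is dropped, so the `tf.debugging.assert_*` calls now run in graph mode too, either resolving at trace time when the shapes are static or being traced as graph ops. Below is a minimal, self-contained sketch of that behaviour (illustrative only, assuming TF 2.x; `check_weights` and `forward` are invented names, not code from this commit):

import tensorflow as tf


def check_weights(attn_weights, expected_shape):
    # Same pattern as the attention layers above: compare the runtime shape with the expected one.
    tf.debugging.assert_equal(
        tf.shape(attn_weights),
        expected_shape,
        message=f"Attention weights should be of size {expected_shape}",
    )
    return attn_weights


@tf.function  # graph mode: the assert becomes a graph op instead of being skipped
def forward(x):
    return check_weights(x, [2, 3])


forward(tf.zeros((2, 3)))    # passes
# forward(tf.zeros((2, 4)))  # raises tf.errors.InvalidArgumentError when the graph runs
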

View File

@ -206,30 +206,24 @@ class TFOPTAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
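
The `shift_tokens_right` hunks keep the assertion but drop the eager-only guard; the `tf.control_dependencies` / `tf.identity` wrapper is what makes the otherwise unconsumed assert op execute when the function is traced into a graph. A standalone sketch of that pattern (illustrative; `validate_labels` is a hypothetical name, not code from the repository):

import tensorflow as tf


@tf.function
def validate_labels(labels):
    # The assert op has no consumers of its own, so it is tied to the returned tensor via a
    # control dependency; tf.identity gives that dependency something to attach to.
    assert_gte0 = tf.debugging.assert_greater_equal(labels, tf.constant(0, dtype=labels.dtype))
    with tf.control_dependencies([assert_gte0]):
        labels = tf.identity(labels)
    return labels


validate_labels(tf.constant([[0, 5, 2]]))   # passes
# validate_labels(tf.constant([[-1, 5]]))   # raises tf.errors.InvalidArgumentError
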
@ -265,30 +264,24 @@ class TFPegasusAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -324,30 +323,24 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
training=training,
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
next_decoder_cache = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids

View File

@ -852,30 +852,24 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
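
All of these checks compare against `shape_list(...)`. As a rough sketch of what such a helper typically returns (an assumption about its behaviour, not the library's exact implementation), static dimensions come back as Python ints and only unknown ones as tensors, which is why many of the comparisons above can be resolved while tracing:

from typing import List, Union

import tensorflow as tf


def shape_list_sketch(t: tf.Tensor) -> List[Union[int, tf.Tensor]]:
    # Static dimensions are returned as Python ints, unknown ones as scalar tensors.
    static = t.shape.as_list()
    dynamic = tf.shape(t)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


x = tf.zeros((4, 16, 16))
print(shape_list_sketch(x))  # [4, 16, 16]: fully static, so the equality check is plain Python
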

View File

@ -239,30 +239,24 @@ class TFXGLMAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {shape_list(attn_weights)}"
),
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(attention_mask)}"
),
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=(
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
f" {shape_list(layer_head_mask)}"
),
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {shape_list(attn_output)}"
),
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
next_decoder_cache = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),

View File

@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
# sanity check
# assert shape_list(mask) == [bs, slen]
if tf.executing_eagerly():
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
if causal:
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
return mask, attn_mask
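
In `get_masks` the bare Python `assert` on the attention-mask shape is replaced by a `tf.debugging.assert_equal` behind a plain `if causal:`. A small sketch of the difference (illustrative only; `check_mask` is a hypothetical helper): a Python `assert` inside a `tf.function` only runs while tracing and disappears under `python -O`, whereas the `tf.debugging` op becomes part of the graph and is checked on every call:

import tensorflow as tf


@tf.function
def check_mask(mask, bs, slen):
    # A bare `assert` here would only run while tracing (and vanish under `python -O`);
    # the tf.debugging op below is part of the traced graph and runs on every call.
    tf.debugging.assert_equal(tf.shape(mask), tf.stack([bs, slen]))
    return mask


check_mask(tf.ones((2, 7)), tf.constant(2), tf.constant(7))  # passes
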
@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
# assert shape_list(lengths)[0] == bs
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
tf.debugging.assert_equal(
    shape_list(lengths)[0],
    bs,
    message=f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched",
)
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
position_ids = tf.expand_dims(tf.range(slen), axis=0)
position_ids = tf.tile(position_ids, (bs, 1))
if tf.executing_eagerly():
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
    shape_list(position_ids),
    [bs, slen],
    message=f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched",
)
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None and tf.executing_eagerly():
if langs is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]

View File

@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
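
The encoder and decoder hunks validate that `head_mask` (and `cross_attn_head_mask`) provide one entry per layer. A hedged, standalone sketch of that check outside any model class (`check_head_mask` and `NUM_LAYERS` are invented for the example):

import tensorflow as tf

NUM_LAYERS = 6  # invented for the example


def check_head_mask(head_mask, num_layers=NUM_LAYERS):
    tf.debugging.assert_equal(
        tf.shape(head_mask)[0],
        num_layers,
        message=f"The head_mask should be specified for {num_layers} layers",
    )
    return head_mask


check_head_mask(tf.ones((NUM_LAYERS, 16)))  # passes; a wrong first dimension raises at run time
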