mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[TF Led] Fix wrong decoder attention mask behavior (#9601)
* fix tf led * remove loop file
This commit is contained in:
parent
85788bae5c
commit
90ca8d36e9
@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
hidden_states = inputs["inputs_embeds"]
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
||||
inputs["attention_mask"] = tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
||||
)
|
||||
inputs["attention_mask"] = tf.concat(
|
||||
[
|
||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
||||
inputs["attention_mask"],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
inputs["attention_mask"] = tf.ones(
|
||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user