mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix CTRL (#9291)
This commit is contained in:
parent
c581d8af7a
commit
6c03d4ac70
@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||
if output_hidden_states:
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h(
|
||||
hidden_states,
|
||||
@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
inputs["use_cache"],
|
||||
output_attentions,
|
||||
inputs["output_attentions"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["use_cache"]:
|
||||
presents = presents + (present,)
|
||||
|
||||
if output_attentions:
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions = all_attentions + (outputs[2],)
|
||||
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
|
Loading…
Reference in New Issue
Block a user