diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 9dceb3163e8..4fcab3c7849 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): all_hidden_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])): - if output_hidden_states: + if inputs["output_hidden_states"]: all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) outputs = h( hidden_states, @@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): inputs["attention_mask"], inputs["head_mask"][i], inputs["use_cache"], - output_attentions, + inputs["output_attentions"], training=inputs["training"], ) hidden_states, present = outputs[:2] @@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): if inputs["use_cache"]: presents = presents + (present,) - if output_attentions: + if inputs["output_attentions"]: all_attentions = all_attentions + (outputs[2],) hidden_states = self.layernorm(hidden_states)