This commit is contained in:
Julien Plu 2021-01-04 15:56:51 +01:00 committed by GitHub
parent c581d8af7a
commit 6c03d4ac70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
hidden_states,
@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["use_cache"],
output_attentions,
inputs["output_attentions"],
training=inputs["training"],
)
hidden_states, present = outputs[:2]
@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if inputs["use_cache"]:
presents = presents + (present,)
if output_attentions:
if inputs["output_attentions"]:
all_attentions = all_attentions + (outputs[2],)
hidden_states = self.layernorm(hidden_states)