Fix TF T5 only encoder model with booleans (#8925)

Lysandre Debut 2020-12-04 12:28:47 -05:00 committed by GitHub
parent dcd3046f98
commit 71688a8889


@@ -1336,14 +1336,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
             return ((loss,) + output) if loss is not None else output
-        # Putting this before breaks tf compilation.
-        output_attentions = (
-            output_attentions if inputs["output_attentions"] is not None else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if inputs["output_hidden_states"] is not None else self.config.output_hidden_states
-        )
         # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
         # TF refuses to compile anymore.
         if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
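For reference, the block removed above hand-rolls the usual "caller value or config default" fallback for the boolean flags. A minimal standalone sketch of that pattern, using an illustrative helper name (resolve_flag is not a transformers function):

def resolve_flag(value, config_default):
    # Use the caller-supplied boolean when one was passed, otherwise fall back to the model config.
    return value if value is not None else config_default

# e.g. output_attentions = resolve_flag(inputs["output_attentions"], self.config.output_attentions)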
@@ -1481,6 +1473,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -1492,12 +1485,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             kwargs_call=kwargs,
         )
-        output_attentions = inputs["output_attentions"] if output_attentions else self.config.output_attentions
-        output_hidden_states = (
-            inputs["output_hidden_states"] if output_hidden_states else self.config.output_hidden_states
-        )
-        return_dict = return_dict if inputs["return_dict"] is not None else self.config.return_dict
         encoder_outputs = self.encoder(
             input_ids,
             attention_mask=inputs["attention_mask"],
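The two TFT5EncoderModel hunks belong together: with config=self.config now passed to input_processing, the returned inputs dict is expected to already hold config-resolved booleans, which is why the manual fallbacks deleted here (they tested the raw call arguments rather than inputs[...]) can go. A rough sketch of that resolution step under this assumption, not the real input_processing implementation:

def resolve_bool_inputs(config, **flags):
    # Stand-in for the boolean handling assumed to happen inside input_processing:
    # any flag left as None is replaced by the corresponding config default.
    defaults = {
        "output_attentions": config.output_attentions,
        "output_hidden_states": config.output_hidden_states,
        "return_dict": config.return_dict,
    }
    return {name: flags.get(name) if flags.get(name) is not None else default
            for name, default in defaults.items()}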
@@ -1507,17 +1494,17 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             head_mask=head_mask,
             past_key_values=None,
             use_cache=False,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
             training=inputs["training"],
         )
-        if not return_dict:
+        if not inputs["return_dict"]:
             return encoder_outputs
-        if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states):
+        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
             encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
-        if not cast_bool_to_primitive(output_attentions, self.config.output_attentions):
+        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
             encoder_outputs = encoder_outputs + (None,)
         return TFBaseModelOutput(
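A quick way to exercise the fixed encoder-only path is to call the model with explicit boolean flags. This usage sketch assumes the public t5-small checkpoint and is not part of the diff:

from transformers import T5Tokenizer, TFT5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5EncoderModel.from_pretrained("t5-small")

enc = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="tf")
outputs = model(enc.input_ids, output_attentions=True, output_hidden_states=True, return_dict=True)

print(outputs.last_hidden_state.shape)   # (batch_size, sequence_length, d_model)
print(len(outputs.hidden_states), len(outputs.attentions))  # per-layer tuples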