Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 10:41:07 +06:00)
Fix TF T5 only encoder model with booleans (#8925)
parent dcd3046f98
commit 71688a8889
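Read together, the hunks below appear to do one thing: stop TFT5EncoderModel.call (plus a leftover block in TFT5ForConditionalGeneration.call) from re-resolving its boolean flags by hand after input_processing has already run. The config is now handed to input_processing (config=self.config), the manual output_attentions / output_hidden_states / return_dict resolution is deleted, and the body reads the processed values, inputs["output_attentions"] and friends, everywhere, including the graph-safe cast_bool_to_primitive checks.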
@@ -1336,14 +1336,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
         return ((loss,) + output) if loss is not None else output
 
-        # Putting this before breaks tf compilation.
-        output_attentions = (
-            output_attentions if inputs["output_attentions"] is not None else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if inputs["output_hidden_states"] is not None else self.config.output_hidden_states
-        )
-
         # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
         # TF refuses to compile anymore.
         if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
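The in-diff comments ("breaks tf compilation", "TF refuses to compile") refer to graph tracing: under tf.function a flag can arrive as a symbolic tf.Tensor rather than a Python bool, so a plain truthiness test is unsafe. A minimal sketch of what a graph-safe cast along the lines of cast_bool_to_primitive can look like; this illustrates the idea and is not the helper's actual implementation:

import tensorflow as tf

def cast_bool_to_primitive_sketch(bool_variable, default=False):
    # Illustrative only: coerce a maybe-tensor flag to a Python bool.
    if tf.is_tensor(bool_variable):
        # Eager tensors carry a concrete value we can read back.
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        # Symbolic tensors (tf.function tracing) do not, so fall back to
        # the supplied default instead of branching on the tensor itself.
        return default
    if bool_variable is None:
        return default
    return bool(bool_variable)

This mirrors the call sites above, where the second argument (for example self.config.use_cache) supplies the fallback.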
@@ -1481,6 +1473,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -1492,12 +1485,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             kwargs_call=kwargs,
         )
 
-        output_attentions = inputs["output_attentions"] if output_attentions else self.config.output_attentions
-        output_hidden_states = (
-            inputs["output_hidden_states"] if output_hidden_states else self.config.output_hidden_states
-        )
-        return_dict = return_dict if inputs["return_dict"] is not None else self.config.return_dict
-
         encoder_outputs = self.encoder(
             input_ids,
             attention_mask=inputs["attention_mask"],
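Note that the deleted block tested the raw call arguments for truthiness (if output_attentions) rather than "is not None", so an explicit False would silently fall back to the config default, and it mixed raw arguments with the already-processed inputs dict. Passing config=self.config to input_processing (the added line in the hunk above) lets the flags be resolved once, up front. A sketch of that resolve-once pattern, with illustrative names rather than the actual input_processing internals:

def resolve_flags(config, **call_kwargs):
    # Prefer the value given at call time; fall back to the config default.
    resolved = {}
    for name in ("output_attentions", "output_hidden_states", "return_dict", "use_cache"):
        value = call_kwargs.get(name)
        # Testing "is not None" keeps an explicit False from being overridden.
        resolved[name] = value if value is not None else getattr(config, name, None)
    return resolved

After this runs, the model body can read resolved["output_attentions"] and friends directly, which is exactly what the remaining hunks switch to.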
@@ -1507,17 +1494,17 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             head_mask=head_mask,
             past_key_values=None,
             use_cache=False,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
             training=inputs["training"],
         )
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             return encoder_outputs
 
-        if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states):
+        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
             encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
-        if not cast_bool_to_primitive(output_attentions, self.config.output_attentions):
+        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
             encoder_outputs = encoder_outputs + (None,)
 
         return TFBaseModelOutput(
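A hedged end-to-end check of the fixed encoder path. The checkpoint name is illustrative, and this assumes a transformers build that contains this commit:

import tensorflow as tf
from transformers import T5Tokenizer, TFT5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint
model = TFT5EncoderModel.from_pretrained("t5-small")

ids = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="tf").input_ids

# Call-time booleans should now reach the encoder through the processed
# inputs dict instead of the raw (possibly unresolved) arguments.
outputs = model(ids, output_hidden_states=True, return_dict=True)
print(len(outputs.hidden_states), outputs.last_hidden_state.shape)

# The same call must also survive graph compilation, which is what the
# in-diff comments are guarding against.
@tf.function
def encode(input_ids):
    # First element is the last hidden state whether or not a dict is returned.
    return model(input_ids)[0]

print(encode(ids).shape)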