Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Fix TF T5 only encoder model with booleans (#8925)
commit 71688a8889
parent dcd3046f98
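In short: this change makes TFT5EncoderModel resolve its boolean flags (output_attentions, output_hidden_states, return_dict) through input_processing, by passing config=self.config and reading the normalized inputs[...] values everywhere, and it drops a now-redundant flag re-normalization block from TFT5ForConditionalGeneration.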
@@ -1336,14 +1336,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
             return ((loss,) + output) if loss is not None else output
 
-        # Putting this before breaks tf compilation.
-        output_attentions = (
-            output_attentions if inputs["output_attentions"] is not None else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states if inputs["output_hidden_states"] is not None else self.config.output_hidden_states
-        )
-
         # This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
         # TF refuses to compile anymore.
         if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
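The deleted block re-derived local flags from a mix of the raw call arguments and the normalized inputs dict. A minimal, self-contained sketch (with made-up values; inputs stands in for the dict returned by input_processing) of the inconsistency that mix can produce once the two disagree:

class Cfg:
    output_attentions = True  # made-up config default for the demo

config = Cfg()
output_attentions = None               # raw argument: the caller passed nothing
inputs = {"output_attentions": False}  # normalized slot: already resolved elsewhere

# Deleted style: tests the normalized slot but returns the *raw* argument.
local_flag = (
    output_attentions if inputs["output_attentions"] is not None else config.output_attentions
)
print(local_flag)                    # None -- neither the normalized False nor the config True

# Retained style: read the normalized value directly, as the kept code does.
print(inputs["output_attentions"])   # False

Reading everything from inputs keeps a single source of truth, which the retained cast_bool_to_primitive(inputs["use_cache"], ...) line already follows.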
@@ -1481,6 +1473,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
         """
         inputs = input_processing(
             func=self.call,
+            config=self.config,
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
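The only addition here is config=self.config. Assuming input_processing uses the config to fill flags left as None (the behavior the later hunks depend on when they read inputs[...] directly), the effect is roughly the following; process_flags is a hypothetical stand-in, not the real helper:

def process_flags(config, **flags):
    """Resolve None flags against config defaults (stand-in for input_processing)."""
    return {
        name: getattr(config, name) if value is None else value
        for name, value in flags.items()
    }

class Cfg:
    output_attentions = False
    output_hidden_states = False
    return_dict = True

inputs = process_flags(
    Cfg(), output_attentions=None, output_hidden_states=True, return_dict=None
)
print(inputs)
# {'output_attentions': False, 'output_hidden_states': True, 'return_dict': True}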
@@ -1492,12 +1485,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             kwargs_call=kwargs,
         )
 
-        output_attentions = inputs["output_attentions"] if output_attentions else self.config.output_attentions
-        output_hidden_states = (
-            inputs["output_hidden_states"] if output_hidden_states else self.config.output_hidden_states
-        )
-        return_dict = return_dict if inputs["return_dict"] is not None else self.config.return_dict
-
         encoder_outputs = self.encoder(
             input_ids,
             attention_mask=inputs["attention_mask"],
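The deleted ternaries here are the boolean bug the commit title refers to: they use the raw argument as a truthiness condition, so an explicit False from the caller is falsy and silently replaced by the config default. A self-contained demo with made-up config values:

class Cfg:
    output_attentions = True  # made-up config default

config = Cfg()
output_attentions = False                          # caller explicitly disables attentions
inputs = {"output_attentions": output_attentions}  # normalized dict mirrors the argument

# Deleted style: truthiness test, so the explicit False falls through to the config.
buggy = inputs["output_attentions"] if output_attentions else config.output_attentions
print(buggy)                         # True -- the caller's False is ignored

# Fixed style: trust the normalized value.
print(inputs["output_attentions"])   # False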
@@ -1507,17 +1494,17 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
             head_mask=head_mask,
             past_key_values=None,
             use_cache=False,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
             training=inputs["training"],
         )
 
-        if not return_dict:
+        if not inputs["return_dict"]:
             return encoder_outputs
 
-        if not cast_bool_to_primitive(output_hidden_states, self.config.output_hidden_states):
+        if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
             encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
-        if not cast_bool_to_primitive(output_attentions, self.config.output_attentions):
+        if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
             encoder_outputs = encoder_outputs + (None,)
 
         return TFBaseModelOutput(
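With the encoder now reading inputs[...] everywhere, explicit booleans are honored end to end. A usage sketch, assuming a transformers version from around this commit and the public t5-small checkpoint:

from transformers import T5Tokenizer, TFT5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5EncoderModel.from_pretrained("t5-small")

enc = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="tf")

# An explicit flag now wins over the config default instead of being dropped.
outputs = model(enc.input_ids, output_hidden_states=True, return_dict=True)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, d_model)
print(len(outputs.hidden_states))       # one per layer, plus the embedding output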