Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix flax gpt2 hidden states (#13109)
* Fix inconsistency of the last element in hidden_states between PyTorch/Flax GPT2(Neo) (#13102)
* Fix missing elements in outputs tuple
* Apply suggestions from code review
* Fix local variable 'all_hidden_states' referenced before assignment
* Fix by returning tuple containing None values
* Fix quality

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent d8fb278a2c
commit a04d4bf2d7
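The user-visible effect of this commit: Flax GPT-2 now places the post-`ln_f` tensor last in `hidden_states`, matching the PyTorch implementation. A minimal sketch of how to check this (assuming a standard `transformers` install with Flax support; "gpt2" is just the usual checkpoint name):

import numpy as np
from transformers import FlaxGPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2Model.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="np")
outputs = model(**inputs, output_hidden_states=True)

# `hidden_states` holds the embedding output plus one entry per block:
# n_layer + 1 entries in total.
assert len(outputs.hidden_states) == model.config.n_layer + 1

# After this fix the last entry is the post-`ln_f` tensor, so it matches
# `last_hidden_state`; before, it was the pre-`ln_f` activation.
assert np.allclose(outputs.hidden_states[-1], outputs.last_hidden_state)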
src/transformers/models/gpt2/modeling_flax_gpt2.py

@@ -24,7 +24,7 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -458,20 +458,10 @@ class FlaxGPT2BlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        outputs = (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-        )
+        # this contains possible `None` values - `FlaxGPT2Module` will filter them out
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+
+        return outputs
 
 
 class FlaxGPT2Module(nn.Module):
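The hunk above and the one below split one responsibility change across two classes: the block collection now always returns a fixed-arity tuple whose optional slots may be `None`, and only the outer module, after applying the final layer norm, appends the post-norm state and filters the placeholders out. A toy sketch of the pattern (names here are illustrative, not the actual transformers API):

from typing import Optional, Tuple

def inner_blocks(x: float, collect: bool) -> Tuple[float, Optional[Tuple[float, ...]]]:
    # Always return the same tuple shape; unused slots stay None so the
    # caller can tell "not requested" apart from "empty".
    history = (x,) if collect else None
    return x, history

def outer_module(x: float, collect: bool, return_dict: bool):
    x, history = inner_blocks(x, collect)
    x = x * 0.5  # stands in for the final layer norm `ln_f`
    if collect:
        history = history + (x,)  # the post-norm state goes last
    outputs = (x, history)
    if not return_dict:
        # None placeholders are dropped only here, at the outermost level
        return tuple(v for v in outputs if v is not None)
    return {"last_hidden_state": x, "hidden_states": history}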
@@ -527,13 +517,19 @@ class FlaxGPT2Module(nn.Module):
         hidden_states = outputs[0]
         hidden_states = self.ln_f(hidden_states)
 
+        if output_hidden_states:
+            all_hidden_states = outputs[1] + (hidden_states,)
+            outputs = (hidden_states, all_hidden_states) + outputs[2:]
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         if not return_dict:
-            return (hidden_states,) + outputs[1:]
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=hidden_states,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            hidden_states=outputs[1],
+            attentions=outputs[-1],
         )
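Issue #13102 was reported as a PyTorch/Flax mismatch, so a cross-framework spot check is the natural regression test. A hedged sketch (assumes both the `torch` and `flax` backends are installed; the tolerance is a guess, not a documented bound):

import numpy as np
import torch
from transformers import FlaxGPT2Model, GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
pt_model = GPT2Model.from_pretrained("gpt2")
fx_model = FlaxGPT2Model.from_pretrained("gpt2")

text = "parity check"
with torch.no_grad():
    pt_out = pt_model(**tokenizer(text, return_tensors="pt"), output_hidden_states=True)
fx_out = fx_model(**tokenizer(text, return_tensors="np"), output_hidden_states=True)

# Compare every entry, including the final post-`ln_f` one that used to
# differ between the two implementations.
for pt_h, fx_h in zip(pt_out.hidden_states, fx_out.hidden_states):
    assert np.allclose(pt_h.numpy(), np.asarray(fx_h), atol=1e-3)

The same check applies, with the corresponding classes, to the GPT-Neo changes below.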
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py

@@ -25,7 +25,7 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt_neo import GPTNeoConfig
@@ -488,20 +488,10 @@ class FlaxGPTNeoBlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        outputs = (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-        )
+        # this contains possible `None` values - `FlaxGPTNeoModule` will filter them out
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+
+        return outputs
 
 
 class FlaxGPTNeoModule(nn.Module):
@@ -557,13 +547,22 @@ class FlaxGPTNeoModule(nn.Module):
         hidden_states = outputs[0]
         hidden_states = self.ln_f(hidden_states)
 
+        hidden_states = outputs[0]
+        hidden_states = self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = outputs[1] + (hidden_states,)
+            outputs = (hidden_states, all_hidden_states) + outputs[2:]
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         if not return_dict:
-            return (hidden_states,) + outputs[1:]
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=hidden_states,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            hidden_states=outputs[1],
+            attentions=outputs[-1],
         )