Fix flax gpt2 hidden states (#13109)

* Fix inconsistency of the last element in hidden_states between PyTorch/Flax GPT2(Neo) (#13102)

* Fix missing elements in outputs tuple

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Fix local variable 'all_hidden_states' referenced before assignment

* Fix by returning tuple containing None values

* Fix quality

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Yih-Dar 2021-08-13 10:45:53 +02:00 committed by GitHub
parent d8fb278a2c
commit a04d4bf2d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 34 deletions

View File

@ -24,7 +24,7 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_gpt2 import GPT2Config
@ -458,20 +458,10 @@ class FlaxGPT2BlockCollection(nn.Module):
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
# this contains possible `None` values - `FlaxGPT2Module` will filter them out
outputs = (hidden_states, all_hidden_states, all_attentions)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
return outputs
class FlaxGPT2Module(nn.Module):
@ -527,13 +517,19 @@ class FlaxGPT2Module(nn.Module):
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return (hidden_states,) + outputs[1:]
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
hidden_states=outputs[1],
attentions=outputs[-1],
)

View File

@ -25,7 +25,7 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_gpt_neo import GPTNeoConfig
@ -488,20 +488,10 @@ class FlaxGPTNeoBlockCollection(nn.Module):
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
# this contains possible `None` values - `FlaxGPTNeoModule` will filter them out
outputs = (hidden_states, all_hidden_states, all_attentions)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
return outputs
class FlaxGPTNeoModule(nn.Module):
@ -557,13 +547,22 @@ class FlaxGPTNeoModule(nn.Module):
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return (hidden_states,) + outputs[1:]
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
hidden_states=outputs[1],
attentions=outputs[-1],
)