Fix PT-TF equivalence test for GPT1 (#22586)

* Re-enable skipped test and fix the hidden state shape issue

* Actually fix the bug instead of just doing something wrong
Matt 2023-04-05 13:16:00 +01:00 committed by GitHub
parent 0684284911
commit 2a91a9ef66
2 changed files with 7 additions and 5 deletions
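
For context on the fix: the double-heads model merges the (batch, num_choices) dimensions before running the transformer, and PyTorch reshapes only the final hidden state back to rank 4. A minimal sketch of that shape behaviour, using made-up sizes (`batch`, `num_choices`, `seq_len`, and `hidden` are illustrative, not taken from the diff):

```python
import tensorflow as tf

# Made-up sizes, purely for illustration.
batch, num_choices, seq_len, hidden = 2, 3, 5, 8

# The double-heads model merges (batch, num_choices) into one leading dimension
# before running the transformer, so every hidden state it produces is rank 3:
flat = tf.random.normal((batch * num_choices, seq_len, hidden))

# PyTorch reshapes only the *final* hidden state back to rank 4; the TF fix
# below mirrors that by swapping the reshaped tensor into the last tuple slot:
final = tf.reshape(flat, (batch, num_choices, seq_len, hidden))

print(flat.shape)   # (6, 5, 8)    - intermediate hidden states stay rank 3
print(final.shape)  # (2, 3, 5, 8) - the final hidden state becomes rank 4
```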

src/transformers/models/openai/modeling_tf_openai.py

@@ -748,6 +748,12 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+        if return_dict and output_hidden_states:
+            # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+            # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+            all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+        else:
+            all_hidden_states = None
         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
@@ -758,7 +764,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         return TFOpenAIGPTDoubleHeadsModelOutput(
             logits=lm_logits,
             mc_logits=mc_logits,
-            hidden_states=transformer_outputs.hidden_states,
+            hidden_states=all_hidden_states,
             attentions=transformer_outputs.attentions,
         )
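
A hedged sketch of how the patched branch behaves end to end. The tiny config values are invented so the example is cheap to run on CPU; only `tensorflow` and `transformers` are assumed to be installed:

```python
import tensorflow as tf
from transformers import OpenAIGPTConfig, TFOpenAIGPTDoubleHeadsModel

# Deliberately tiny, made-up config so the example runs quickly.
config = OpenAIGPTConfig(vocab_size=100, n_positions=32, n_embd=16, n_layer=2, n_head=2)
model = TFOpenAIGPTDoubleHeadsModel(config)

# Rank-3 input: (batch, num_choices, seq_len).
input_ids = tf.random.uniform((1, 2, 10), maxval=config.vocab_size, dtype=tf.int32)
mc_token_ids = tf.constant([[9, 9]])  # position of the classification token in each choice

outputs = model(input_ids, mc_token_ids=mc_token_ids, output_hidden_states=True, return_dict=True)

# With the fix, only the last entry is reshaped back to rank 4; the earlier
# hidden states keep the merged (batch * num_choices) leading dimension.
print(outputs.hidden_states[0].shape)   # (2, 10, 16)
print(outputs.hidden_states[-1].shape)  # (1, 2, 10, 16)
```

With the old code, `outputs.hidden_states[-1]` would also have been rank 3, which is what made the PT-TF comparison fail.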

tests/models/openai/test_modeling_openai.py

@@ -274,10 +274,6 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
             model = OpenAIGPTModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
-    @unittest.skip("Fix me Matt")
-    def test_pt_tf_model_equivalence(self):
-        pass
-
 
 @require_torch
 class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
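
With the skip removed, `OpenAIGPTModelTest` again inherits `test_pt_tf_model_equivalence` from `ModelTesterMixin`. One way to run just that test (the file path assumes a standard transformers checkout, and the PT-TF cross tests are gated behind an environment variable in this test suite):

```python
import os
import pytest

# The PT-TF cross tests are gated behind this environment variable in the
# transformers test suite; it must be set before the test module is imported.
os.environ["RUN_PT_TF_CROSS_TESTS"] = "1"

# Assumed repo-relative path; run from the root of a transformers checkout.
pytest.main(["tests/models/openai/test_modeling_openai.py", "-k", "test_pt_tf_model_equivalence", "-v"])
```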