Fix PT-TF equivalence test for GPT1 (#22586)
* Re-enable skipped test and fix the hidden state shape issue
* Actually fix the bug instead of just doing something wrong
This commit is contained in:
parent 0684284911
commit 2a91a9ef66
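Context for the patch below: TFOpenAIGPTDoubleHeadsModel takes rank-3 input of shape (batch_size, num_choices, seq_len), and the transformer body flattens the first two dimensions, so the per-layer hidden states come back at rank 3. Only the final hidden state is reshaped back to rank 4, mirroring the PyTorch model. A minimal sketch of that reshape, using made-up shapes rather than the model's real config:

import tensorflow as tf

# Hypothetical shapes for illustration only.
batch_size, num_choices, seq_len, hidden_size = 2, 3, 5, 8
input_shapes = [batch_size, num_choices, seq_len]  # rank-3 input shape

# The transformer merges the first two dims, so its output is rank 3.
hidden_states = tf.random.normal((batch_size * num_choices, seq_len, hidden_size))

# The final hidden state is reshaped back to rank 4, as in the patch.
hidden_states = tf.reshape(hidden_states, input_shapes + [hidden_size])
print(hidden_states.shape)  # (2, 3, 5, 8)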
@@ -748,6 +748,12 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+        if return_dict and output_hidden_states:
+            # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+            # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+            all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+        else:
+            all_hidden_states = None
         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
@@ -758,7 +764,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         return TFOpenAIGPTDoubleHeadsModelOutput(
             logits=lm_logits,
             mc_logits=mc_logits,
-            hidden_states=transformer_outputs.hidden_states,
+            hidden_states=all_hidden_states,
             attentions=transformer_outputs.attentions,
         )
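The tuple surgery in the first hunk (transformer_outputs.hidden_states[:-1] + (hidden_states,)) keeps every intermediate hidden state at rank 3 and swaps in the rank-4 tensor only for the last entry, so the returned hidden_states tuple matches what the PyTorch model emits. A toy illustration of the same pattern, with stand-in values rather than real tensors:

# Stand-ins for the transformer's per-layer hidden states (rank 3 each).
per_layer = ("layer0_rank3", "layer1_rank3", "final_rank3")
final_rank4 = "final_rank4"  # the reshaped tensor from the model body

# Replace only the last element, exactly the pattern used in the patch.
all_hidden_states = per_layer[:-1] + (final_rank4,)
print(all_hidden_states)  # ('layer0_rank3', 'layer1_rank3', 'final_rank4')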
@@ -274,10 +274,6 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         model = OpenAIGPTModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
-    @unittest.skip("Fix me Matt")
-    def test_pt_tf_model_equivalence(self):
-        pass
-
 
 @require_torch
 class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
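With the skip removed, OpenAIGPTModelTest falls back to the shared test_pt_tf_model_equivalence inherited from the common test mixin, which runs the same inputs through the PyTorch and TensorFlow models and compares the outputs numerically. A rough sketch of that style of check; the helper name and tolerance here are assumptions, not the mixin's actual code:

import numpy as np

def assert_pt_tf_close(pt_tensor, tf_tensor, atol=1e-5):
    # Hypothetical helper: compare a PyTorch tensor and a TF tensor via NumPy.
    pt_np = pt_tensor.detach().cpu().numpy()
    tf_np = tf_tensor.numpy()
    max_diff = np.max(np.abs(pt_np - tf_np))
    assert max_diff <= atol, f"PT/TF outputs differ: max abs diff {max_diff:.2e}"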