mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix test
This commit is contained in:
parent
a3274ac40b
commit
bcc9e93e6f
@ -152,9 +152,10 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
self.parent.assertEqual(self.n_layer, len(result["presents"]))
|
||||
self.parent.assertListEqual(
|
||||
list(result["presents"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
list(result["presents"][0].size()),
|
||||
[2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head])
|
||||
|
||||
def check_gpt2_lm_head_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
|
Loading…
Reference in New Issue
Block a user