fix tests

This commit is contained in:
thomwolf 2019-02-09 17:07:12 +01:00
parent f0bf81e141
commit 9bdcba53fd

View File

@ -93,7 +93,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
if self.use_labels:
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,