diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index d98fc441312..5048930b561 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -341,5 +341,5 @@ class TFXLMModelLanguageGenerationTest(unittest.TestCase): 447, ] # the president the president the president the president the president the president the president the president the president the president # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference - output_ids = model.generate(input_ids) - self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids, do_sample=False) + output_ids = model.generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)