From 7eaa90b87bcbd5013737fa183a5f7166c891fa9f Mon Sep 17 00:00:00 2001
From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Mon, 12 May 2025 16:51:21 +0200
Subject: [PATCH] Add AMD expectation to test_gpt2_sample (#38079)

---
 tests/models/gpt2/test_modeling_gpt2.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index 44e342539b8..ed172fd91dd 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -20,6 +20,7 @@ import pytest
 from transformers import GPT2Config, is_torch_available
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_flash_attn,
     require_torch,
@@ -817,10 +818,14 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
         output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
         output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
 
-        EXPECTED_OUTPUT_STR = (
-            "Today is a nice day and if you don't know anything about the state of play during your holiday"
-        )
-        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
+        expected_outputs = Expectations(
+            {
+                ("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
+                ("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
+            }
+        )  # fmt: skip
+        EXPECTED_OUTPUT = expected_outputs.get_expectation()
+        self.assertEqual(output_str, EXPECTED_OUTPUT)
         self.assertTrue(
             all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
         )  # token_type_ids should change output
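
For context on the pattern introduced above: the patch replaces a single hard-coded expected string with an Expectations table keyed by (device_type, major_version) tuples, and get_expectation() returns the entry for the accelerator the test is running on, so the AMD (ROCm) runner gets its own expected sample output. Below is a minimal, self-contained sketch of that lookup pattern. DeviceExpectations and its explicit device_type argument are hypothetical stand-ins for illustration, not the actual transformers.testing_utils.Expectations implementation, which detects the current device itself.

    # Minimal sketch (assumption): a device-keyed expectation table in the
    # spirit of the patch above. This is NOT the real
    # transformers.testing_utils.Expectations class; the device type is
    # passed explicitly here for clarity.
    from typing import Dict, Optional, Tuple

    Key = Tuple[str, Optional[int]]


    class DeviceExpectations:
        def __init__(self, table: Dict[Key, str]):
            # Keys are (device_type, major_version); a None version means
            # "any version of this device type".
            self.table = table

        def get_expectation(self, device_type: str, major: Optional[int] = None) -> str:
            # Prefer an exact (device_type, major) match, then fall back to
            # the version-agnostic (device_type, None) entry.
            if (device_type, major) in self.table:
                return self.table[(device_type, major)]
            return self.table[(device_type, None)]


    expected_outputs = DeviceExpectations(
        {
            ("rocm", None): 'Today is a nice day and we can do this again."\n\nDana said that she will',
            ("cuda", None): "Today is a nice day and if you don't know anything about the state of play during your holiday",
        }
    )
    # On an AMD (ROCm) runner the ROCm string is selected; on NVIDIA, the CUDA one.
    print(expected_outputs.get_expectation("rocm"))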