fix mistral and mistral3 tests (#38978)

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-23 17:07:18 +02:00 committed by GitHub
parent b6b4d43d6d
commit 2ce02b98bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -112,6 +112,7 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch_accelerator
@require_read_token
class MistralIntegrationTest(unittest.TestCase):
# This variable is used to determine which accelerator we are using for our runners (e.g. A10 or T4)
# Depending on the hardware we get different logits / generations
@ -121,6 +122,9 @@ class MistralIntegrationTest(unittest.TestCase):
def setUpClass(cls):
    """Record the runner's accelerator properties once for the whole class.

    The stored properties are used to pick hardware-specific expected
    logits/generations (e.g. A10 vs T4) in the individual tests.
    """
    cls.device_properties = get_device_properties()
def setUp(self):
    # Free accelerator memory (and force a GC pass) before each test so that
    # allocations from a previous test cannot skew this one.
    cleanup(torch_device, gc_collect=True)
def tearDown(self):
    # Release accelerator memory after each test; gc_collect=True also runs
    # Python garbage collection to drop lingering tensor references.
    cleanup(torch_device, gc_collect=True)
@ -256,7 +260,7 @@ class MistralIntegrationTest(unittest.TestCase):
@slow
def test_speculative_generation(self):
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I love it on everything. Im not a big"
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% Sriracha. I love it on everything. I have it on my"
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
@ -273,7 +277,6 @@ class MistralIntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
@require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
@ -339,23 +342,32 @@ class MistralIntegrationTest(unittest.TestCase):
@require_torch_accelerator
class Mask4DTestHard(unittest.TestCase):
model_name = "mistralai/Mistral-7B-v0.1"
_model = None
model = None
model_dtype = None
@classmethod
def setUpClass(cls):
    """Load the shared Mistral model once per class, after clearing memory.

    The model and dtype are cached on the class so every test in this class
    reuses a single on-device copy instead of reloading the checkpoint.
    """
    # Start from a clean accelerator state before the expensive load.
    cleanup(torch_device, gc_collect=True)

    if cls.model_dtype is None:
        cls.model_dtype = torch.float16
    if cls.model is None:
        loaded = MistralForCausalLM.from_pretrained(cls.model_name, torch_dtype=cls.model_dtype)
        cls.model = loaded.to(torch_device)
@classmethod
def tearDownClass(cls):
    # Drop the class-level references so the model can actually be freed,
    # then reclaim accelerator memory for the next test class.
    del cls.model_dtype
    del cls.model
    cleanup(torch_device, gc_collect=True)
def setUp(self):
    # Clean accelerator state, then build a fresh (slow) tokenizer per test.
    cleanup(torch_device, gc_collect=True)
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
def tearDown(self):
    # Release accelerator memory after each test run.
    cleanup(torch_device, gc_collect=True)
@property
def model(self):
    """Lazily load the shared model on first access and cache it on the class.

    All instances of this test class share one on-device model, so the
    expensive checkpoint load happens at most once.
    """
    klass = self.__class__
    if klass._model is None:
        loaded = MistralForCausalLM.from_pretrained(self.model_name, torch_dtype=self.model_dtype)
        klass._model = loaded.to(torch_device)
    return klass._model
def setUp(self):
    # Fix the dtype used by the lazy `model` property and create a fresh
    # (slow) tokenizer for each test.
    self.model_dtype = torch.float16
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item