mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-02 20:30:41 +06:00
fix mistral
and mistral3
tests (#38978)
* fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
b6b4d43d6d
commit
2ce02b98bf
@ -112,6 +112,7 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_read_token
|
||||
class MistralIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
@ -121,6 +122,9 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
cls.device_properties = get_device_properties()
|
||||
|
||||
def setUp(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@ -256,7 +260,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"
|
||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% Sriracha. I love it on everything. I have it on my"
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
@ -273,7 +277,6 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
@ -339,23 +342,32 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
@require_torch_accelerator
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
model_name = "mistralai/Mistral-7B-v0.1"
|
||||
_model = None
|
||||
model = None
|
||||
model_dtype = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
if cls.model_dtype is None:
|
||||
cls.model_dtype = torch.float16
|
||||
if cls.model is None:
|
||||
cls.model = MistralForCausalLM.from_pretrained(cls.model_name, torch_dtype=cls.model_dtype).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
del cls.model_dtype
|
||||
del cls.model
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def setUp(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
if self.__class__._model is None:
|
||||
self.__class__._model = MistralForCausalLM.from_pretrained(
|
||||
self.model_name, torch_dtype=self.model_dtype
|
||||
).to(torch_device)
|
||||
return self.__class__._model
|
||||
|
||||
def setUp(self):
|
||||
self.model_dtype = torch.float16
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
Loading…
Reference in New Issue
Block a user