Fix fsmt tests (#38904)

* fix 1

* fix 2

* fix 3

* fix 4

* fix 5

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-19 10:56:34 +02:00 committed by GitHub
parent 11738f8537
commit b949747b54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -474,7 +474,16 @@ class FSMTModelIntegrationTests(unittest.TestCase):
def get_model(self, mname): def get_model(self, mname):
if mname not in self.models_cache: if mname not in self.models_cache:
self.models_cache[mname] = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device) # The safetensors checkpoint on `facebook/wmt19-de-en` (and other repositories) has issues.
# Hub PRs are opened, see https://huggingface.co/facebook/wmt19-de-en/discussions/6
# We have asked Meta to merge them but no response yet:
# https://huggingface.slack.com/archives/C01NE71C4F7/p1749565278015529?thread_ts=1749031628.757929&cid=C01NE71C4F7
# Below is what produced the Hub PRs that work (loading without safetensors, saving the reloading)
model = FSMTForConditionalGeneration.from_pretrained(mname, use_safetensors=False)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
self.models_cache[mname] = FSMTForConditionalGeneration.from_pretrained(tmpdir).to(torch_device)
if torch_device == "cuda": if torch_device == "cuda":
self.models_cache[mname].half() self.models_cache[mname].half()
return self.models_cache[mname] return self.models_cache[mname]
@ -497,7 +506,7 @@ class FSMTModelIntegrationTests(unittest.TestCase):
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]] [[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]]
).to(torch_device) ).to(torch_device)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) torch.testing.assert_close(output[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
def translation_setup(self, pair): def translation_setup(self, pair):
text = { text = {
@ -512,6 +521,10 @@ class FSMTModelIntegrationTests(unittest.TestCase):
src_text = text[src] src_text = text[src]
tgt_text = text[tgt] tgt_text = text[tgt]
# To make `test_translation_pipeline_0_en_ru` pass in #38904. When translating it back to `en`, we get
# `Machine learning is fine, isn't it?`.
if (src, tgt) == ("en", "ru"):
tgt_text = "Машинное обучение - это прекрасно, не так ли?"
tokenizer = self.get_tokenizer(mname) tokenizer = self.get_tokenizer(mname)
model = self.get_model(mname) model = self.get_model(mname)