Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Roberta is ExecuTorch compatible (#34425)
* Roberta is ExecuTorch compatible

* [run_slow] roberta

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
parent 9bee9ff5db
commit cd277618d4
@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import RobertaConfig, is_torch_available
+from transformers import AutoTokenizer, RobertaConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -41,6 +41,7 @@ if is_torch_available():
         RobertaEmbeddings,
         create_position_ids_from_input_ids,
     )
+    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 
 ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
 
@@ -576,3 +577,43 @@ class RobertaModelIntegrationTest(TestCasePlus):
         # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
 
         self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if not is_torch_greater_or_equal_than_2_4:
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        roberta_model = "FacebookAI/roberta-base"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(roberta_model)
+        inputs = tokenizer(
+            "The goal of life is <mask>.",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = RobertaForMaskedLM.from_pretrained(
+            roberta_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+            use_cache=True,
+        )
+
+        logits = model(**inputs).logits
+        eager_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eager_predicted_mask.split(), ["happiness", "love", "peace", "freedom", "simplicity"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        exported_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eager_predicted_mask, exported_predicted_mask)
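Note: the new test_export only checks that RoBERTa survives torch.export.export with the same mask-fill predictions as eager mode; the ExecuTorch lowering itself is not exercised by this commit. A minimal sketch of that next step is shown below, assuming the executorch package is installed and using its executorch.exir.to_edge API; the model ID, input sentence, and output file name are illustrative, not part of this commit.

# Sketch only: lower the torch.export program to an ExecuTorch .pte artifact.
# Assumes `executorch` is installed; `to_edge` comes from executorch.exir.
import torch
from executorch.exir import to_edge
from transformers import AutoTokenizer, RobertaForMaskedLM

model_id = "FacebookAI/roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RobertaForMaskedLM.from_pretrained(model_id, attn_implementation="sdpa")
model.eval()

inputs = tokenizer("The goal of life is <mask>.", return_tensors="pt")

# Same export call as in test_export, without the fixed 512-token padding.
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)

# Lower to the Edge dialect, then to an ExecuTorch program, and serialize it.
edge_program = to_edge(exported)
et_program = edge_program.to_executorch()
with open("roberta.pte", "wb") as f:
    f.write(et_program.buffer)

The resulting .pte file would be what the ExecuTorch runtime loads on device; the exportability check added here is the compatibility gate for that lowering.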