mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[tests] remove test_export_to_onnx
(#36241)
This commit is contained in:
parent
dae8708c36
commit
429f1a682d
@ -262,20 +262,6 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
model = FSMTModel(config).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(inputs_dict["input_ids"], inputs_dict["attention_mask"]),
|
||||
f"{tmpdirname}/fsmt_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=12,
|
||||
input_names=["input_ids", "attention_mask"],
|
||||
)
|
||||
|
||||
def test_ensure_weights_are_shared(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
|
@ -627,20 +627,6 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
model = LongT5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = LongT5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/longt5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -871,20 +871,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
model = MT5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = MT5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/t5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -26,7 +26,6 @@ from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.testing_utils import (
|
||||
require_essentia,
|
||||
require_librosa,
|
||||
require_onnx,
|
||||
require_scipy,
|
||||
require_torch,
|
||||
slow,
|
||||
@ -611,20 +610,6 @@ class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_onnx
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = Pop2PianoForConditionalGeneration(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/Pop2Piano_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=14,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_pass_with_input_features(self):
|
||||
input_features = BatchFeature(
|
||||
{
|
||||
|
@ -709,20 +709,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
model = SwitchTransformersModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = SwitchTransformersModel(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/switch_transformers_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -875,20 +875,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
model = T5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/t5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -525,20 +525,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = UMT5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/t5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||
def test_model_fp16_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
Loading…
Reference in New Issue
Block a user