mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix flaky ONNX tests (#6531)
This commit is contained in:
parent
39c3b1d9de
commit
b41cc0b86a
@ -1,7 +1,5 @@
|
||||
import unittest
|
||||
from os.path import dirname, exists
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
||||
@ -72,7 +70,7 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
def test_quantize_pytorch(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "pt", 12)
|
||||
quantized_path = quantize(Path(path))
|
||||
quantized_path = quantize(path)
|
||||
|
||||
# Ensure the actual quantized model is not bigger than the original one
|
||||
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
||||
@ -82,16 +80,16 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
try:
|
||||
# Compute path
|
||||
with TemporaryDirectory() as tempdir:
|
||||
path = tempdir + "/model.onnx"
|
||||
path = Path(tempdir).joinpath("model.onnx")
|
||||
|
||||
# Remove folder if exists
|
||||
if exists(dirname(path)):
|
||||
rmtree(dirname(path))
|
||||
if path.parent.exists():
|
||||
path.parent.rmdir()
|
||||
|
||||
# Export
|
||||
convert(framework, model, path, opset, tokenizer)
|
||||
# Export
|
||||
convert(framework, model, path, opset, tokenizer)
|
||||
|
||||
return path
|
||||
return path
|
||||
except Exception as e:
|
||||
self.fail(e)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user