mirror of https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
parent 561b9a8c00
commit 99eb9b523f
--- a/examples/pytorch/test_accelerate_examples.py
+++ b/examples/pytorch/test_accelerate_examples.py
@@ -19,14 +19,14 @@ import json
 import logging
 import os
 import shutil
 import subprocess
 import sys
 import tempfile
 from unittest import mock

 import torch

 from accelerate.utils import write_basic_config
-from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
+from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
 from transformers.utils import is_apex_available

@@ -75,6 +75,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
     def tearDownClass(cls):
         shutil.rmtree(cls.tmpdir)

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_glue_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -94,12 +95,13 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         if is_cuda_and_apex_available():
             testargs.append("--fp16")

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_clm_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -120,12 +122,13 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             # Skipping because there are not enough batches to train the model + would need a drop_last to work.
             return

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertLess(result["perplexity"], 100)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_mlm_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -139,12 +142,13 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertLess(result["perplexity"], 42)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_ner_no_trainer(self):
         # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
         epochs = 7 if get_gpu_count() > 1 else 2
@@ -165,13 +169,14 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
         self.assertLess(result["train_loss"], 0.5)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_squad_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -190,7 +195,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
         self.assertGreaterEqual(result["eval_f1"], 28)
@@ -198,6 +203,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_swag_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -214,12 +220,13 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.8)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))

     @slow
+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_summarization_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -237,7 +244,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_rouge1"], 10)
         self.assertGreaterEqual(result["eval_rouge2"], 2)
@@ -247,6 +254,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))

     @slow
+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_translation_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -268,7 +276,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --with_tracking
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_bleu"], 30)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
@@ -292,10 +300,11 @@ class ExamplesTestsNoTrainer(TestCasePlus):
             --checkpointing_steps epoch
         """.split()

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)

+    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
     def test_run_image_classification_no_trainer(self):
         tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
@@ -316,9 +325,9 @@ class ExamplesTestsNoTrainer(TestCasePlus):
         if is_cuda_and_apex_available():
             testargs.append("--fp16")

-        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        run_command(self._launch_args + testargs)
         result = get_results(tmp_dir)
         # The base model scores a 25%
-        self.assertGreaterEqual(result["eval_accuracy"], 0.625)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.6)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "step_1")))
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
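Every hunk above applies the same swap, so the failure mode it addresses is worth spelling out. A minimal self-contained sketch (the crashing inline script is illustrative, not taken from the test suite):

import subprocess
import sys

# Old pattern: the launched script crashes, but the test keeps running,
# because the CompletedProcess (and its non-zero returncode) is discarded.
crashing = [sys.executable, "-c", "raise RuntimeError('training crashed')"]
_ = subprocess.run(crashing, stdout=subprocess.PIPE)
print("old pattern: test body continues after the crash")

# New pattern: subprocess.check_output raises CalledProcessError on a
# non-zero exit; the run_command helper added below re-raises it as
# SubprocessCallException carrying the captured output.
try:
    subprocess.check_output(crashing, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
    print("new pattern: the failure surfaces with output:", e.output.decode())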
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -20,6 +20,7 @@ import os
 import re
 import shlex
 import shutil
+import subprocess
 import sys
 import tempfile
 import unittest
@@ -27,7 +28,7 @@ from collections.abc import Mapping
 from distutils.util import strtobool
 from io import StringIO
 from pathlib import Path
-from typing import Iterator, Union
+from typing import Iterator, List, Union
 from unittest import mock

 from transformers import logging as transformers_logging
@@ -1561,3 +1562,25 @@ def to_2tuple(x):
     if isinstance(x, collections.abc.Iterable):
         return x
     return (x, x)
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+    pass
+
+
+def run_command(command: List[str], return_stdout=False):
+    """
+    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
+    if an error occurred while running `command`
+    """
+    try:
+        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+        if return_stdout:
+            if hasattr(output, "decode"):
+                output = output.decode("utf-8")
+            return output
+    except subprocess.CalledProcessError as e:
+        raise SubprocessCallException(
+            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+        ) from e
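A usage sketch of the new helper, with both names taken from the hunk above (the inline commands are illustrative):

import sys

from transformers.testing_utils import SubprocessCallException, run_command

# With return_stdout=True, the decoded stdout of a successful command is returned.
out = run_command([sys.executable, "-c", "print('ok')"], return_stdout=True)
assert out.strip() == "ok"

# A non-zero exit is converted into SubprocessCallException, whose message
# embeds the command and its captured stdout/stderr.
try:
    run_command([sys.executable, "-c", "import sys; sys.exit('boom')"])
except SubprocessCallException as err:
    assert "boom" in str(err)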