Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
transition to new tests dir (#10080)
commit 781220acab
parent 84acf0c7bb
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import unittest
 
@@ -19,13 +20,17 @@ from transformers.integrations import is_deepspeed_available
 from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
-from utils import load_json
 
 
 set_seed(42)
 MBART_TINY = "sshleifer/tiny-mbart"
 
 
+def load_json(path):
+    with open(path) as f:
+        return json.load(f)
+
+
 # a candidate for testing_utils
 def require_deepspeed(test_case):
     """
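
The hunk above cuts off require_deepspeed at its opening docstring. What follows is a minimal sketch (not the body from this commit) of how such a skip decorator is typically written, assuming it only needs is_deepspeed_available and unittest from the imports already shown:

def require_deepspeed(test_case):
    """
    Decorator marking a test that requires deepspeed.
    Sketch only; the docstring and skip message in the actual file may differ.
    """
    if not is_deepspeed_available():
        return unittest.skip("test requires deepspeed")(test_case)
    return test_case

Applied as @require_deepspeed on a test method or class, this skips the test on machines where the deepspeed package is not installed.
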
@@ -122,7 +127,7 @@ class TestDeepSpeed(TestCasePlus):
 
         ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
         distributed_args = f"""
-            {self.test_file_dir}/finetune_trainer.py
+            {self.test_file_dir}/../../seq2seq/finetune_trainer.py
         """.split()
         cmd = ["deepspeed"] + distributed_args + args + ds_args
         # keep for quick debug