mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[examples tests] various fixes (#10584)
* fix sharded ddp enum
* test fixes
* stronger validation + apex breaks other tests
This commit is contained in:
parent
6f84531e61
commit
917f104502
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
@ -23,6 +24,7 @@ from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
slow,
|
||||
@ -65,13 +67,26 @@ def require_apex(test_case):
|
||||
|
||||
class TestTrainerExt(TestCasePlus):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
model_name=MBART_TINY,
|
||||
num_train_epochs=1,
|
||||
distributed=distributed,
|
||||
extra_args_str=extra_args_str,
|
||||
predict_with_generate=predict_with_generate,
|
||||
)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
|
||||
first_step_stats = eval_metrics[0]
|
||||
if predict_with_generate:
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
last_step_stats = eval_metrics[-1]
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
def test_run_seq2seq_no_dist(self):
|
||||
self.run_seq2seq_quick()
|
||||
@ -98,29 +113,47 @@ class TestTrainerExt(TestCasePlus):
|
||||
def test_run_seq2seq_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
|
||||
|
||||
# test --sharded_ddp zero2 w/o --fp16
|
||||
# test --sharded_ddp zero_dp_2 w/o --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
@unittest.skip("XXX: Fixme: hanging")
|
||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
||||
|
||||
# test --sharded_ddp zero2 w/ --fp16
|
||||
# test --sharded_ddp zero_dp_2 w/ --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
@unittest.skip("XXX: Fixme: hanging")
|
||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
||||
self.run_seq2seq_quick(
|
||||
distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
|
||||
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
||||
)
|
||||
|
||||
@require_apex
|
||||
@require_torch_gpu
|
||||
def test_run_seq2seq_apex(self):
|
||||
self.run_seq2seq_quick(extra_args_str="--fp16 --fp16_backend=apex")
|
||||
# XXX: apex breaks the trainer if it's run twice e.g. run_seq2seq.main() from the same
|
||||
# program and it breaks other tests that run from the same pytest worker, therefore until this is
|
||||
# sorted out it must be run only in an external program, that is distributed=True in this
|
||||
# test and only under one or more gpus - if we want cpu will need to make a special test
|
||||
#
|
||||
# specifically, the problem was traced to self.optimizer.step() - if it's run a 2nd time via a
|
||||
# 2nd main() call it botches the future eval.
|
||||
#
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
# test a 2nd time - was getting 'eval_loss': nan
|
||||
# to reproduce the problem set distributed=False
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@slow
|
||||
def test_run_seq2seq_slow(self):
|
||||
# There is a missing call to init_process_group somewhere
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
|
||||
eval_steps=2,
|
||||
max_len=128,
|
||||
model_name=MARIAN_MODEL,
|
||||
learning_rate=3e-4,
|
||||
num_train_epochs=10,
|
||||
distributed=False,
|
||||
)
|
||||
|
||||
# Check metrics
|
||||
@ -129,21 +162,22 @@ class TestTrainerExt(TestCasePlus):
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_preds_seq2seq.txt" in contents
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
def run_trainer(
|
||||
self,
|
||||
eval_steps: int,
|
||||
max_len: str,
|
||||
max_len: int,
|
||||
model_name: str,
|
||||
num_train_epochs: int,
|
||||
learning_rate: float = 3e-3,
|
||||
distributed: bool = False,
|
||||
extra_args_str: str = None,
|
||||
predict_with_generate: bool = True,
|
||||
@ -168,7 +202,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--learning_rate 3e-3
|
||||
--learning_rate {learning_rate}
|
||||
--warmup_steps 8
|
||||
--evaluation_strategy steps
|
||||
--logging_steps 0
|
||||
|
@ -425,6 +425,6 @@ class TrainerMemoryTracker:
|
||||
|
||||
class ShardedDDPOption(ExplicitEnum):
|
||||
SIMPLE = "simple"
|
||||
ZERO_DP_2 = "zero2"
|
||||
ZERO_DP_3 = "zero3"
|
||||
ZERO_DP_2 = "zero_dp_2"
|
||||
ZERO_DP_3 = "zero_dp_3"
|
||||
OFFLOAD = "offload"
|
||||
|
Loading…
Reference in New Issue
Block a user