Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
commit 9f7b2b2432 (parent dc552b9b70)
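The pattern this commit applies across the seq2seq example tests: standalone pytest functions that created their own temp dirs with tempfile.mkdtemp() become methods on transformers.testing_utils.TestCasePlus, whose get_auto_remove_tmp_dir() returns a directory that is deleted automatically when the test finishes. A minimal before/after sketch (test bodies elided; only the scaffolding mirrors the diff below):

import tempfile

from transformers.testing_utils import TestCasePlus


def test_old_style():
    # before: module-level pytest function; the temp dir is never cleaned up
    output_dir = tempfile.mkdtemp(prefix="output_")
    ...


class TestNewStyle(TestCasePlus):
    def test_new_style(self):
        # after: unittest-style method; the dir is removed automatically at test end
        output_dir = self.get_auto_remove_tmp_dir()
        ...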
@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
 
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
 
 
@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_model_download():
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
class TestAll(TestCasePlus):
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_model_download(self):
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
|
||||
@timeout_decorator.timeout(120)
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_train_mbart_cc25_enro_script(self):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
"--fp16_opt_level=O1": "",
|
||||
"$MAX_LEN": 128,
|
||||
"$BS": 4,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": data_dir,
|
||||
"facebook/mbart-large-cc25": MODEL_NAME,
|
||||
# Download is 120MB in previous test.
|
||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||
}
|
||||
|
||||
@timeout_decorator.timeout(120)
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_train_mbart_cc25_enro_script():
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
"--fp16_opt_level=O1": "",
|
||||
"$MAX_LEN": 128,
|
||||
"$BS": 4,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": data_dir,
|
||||
"facebook/mbart-large-cc25": MODEL_NAME,
|
||||
# Download is 120MB in previous test.
|
||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||
}
|
||||
# Clean up bash script
|
||||
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
# Clean up bash script
|
||||
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = tempfile.mkdtemp(prefix="output_mbart")
|
||||
bash_script = bash_script.replace("--fp16 ", "")
|
||||
testargs = (
|
||||
["finetune.py"]
|
||||
+ bash_script.split()
|
||||
+ [
|
||||
f"--output_dir={output_dir}",
|
||||
"--gpus=1",
|
||||
"--learning_rate=3e-1",
|
||||
"--warmup_steps=0",
|
||||
"--val_check_interval=1.0",
|
||||
"--tokenizer_name=facebook/mbart-large-en-ro",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
args.do_predict = False
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
model = main(args)
|
||||
|
||||
bash_script = bash_script.replace("--fp16 ", "")
|
||||
testargs = (
|
||||
["finetune.py"]
|
||||
+ bash_script.split()
|
||||
+ [
|
||||
f"--output_dir={output_dir}",
|
||||
"--gpus=1",
|
||||
"--learning_rate=3e-1",
|
||||
"--warmup_steps=0",
|
||||
"--val_check_interval=1.0",
|
||||
"--tokenizer_name=facebook/mbart-large-en-ro",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
args.do_predict = False
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
model = main(args)
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
assert (
|
||||
len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
|
||||
) # +1 accounts for val_sanity_check
|
||||
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1 # +1 accounts for val_sanity_check
|
||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||
|
||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
|
||||
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
@timeout_decorator.timeout(600)
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_opus_mt_distill_script(self):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
"--fp16_opt_level=O1": "",
|
||||
"$MAX_LEN": 128,
|
||||
"$BS": 16,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": data_dir,
|
||||
"$m": "sshleifer/student_marian_en_ro_6_1",
|
||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||
}
|
||||
|
||||
# Clean up bash script
|
||||
bash_script = (
|
||||
Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
|
||||
)
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
bash_script = bash_script.replace("--fp16 ", " ")
|
||||
|
||||
@timeout_decorator.timeout(600)
|
||||
@slow
|
||||
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
|
||||
def test_opus_mt_distill_script():
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
"--fp16_opt_level=O1": "",
|
||||
"$MAX_LEN": 128,
|
||||
"$BS": 16,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": data_dir,
|
||||
"$m": "sshleifer/student_marian_en_ro_6_1",
|
||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||
}
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
bash_script = bash_script.replace("--fp16", "")
|
||||
epochs = 6
|
||||
testargs = (
|
||||
["distillation.py"]
|
||||
+ bash_script.split()
|
||||
+ [
|
||||
f"--output_dir={output_dir}",
|
||||
"--gpus=1",
|
||||
"--learning_rate=1e-3",
|
||||
f"--num_train_epochs={epochs}",
|
||||
"--warmup_steps=10",
|
||||
"--val_check_interval=1.0",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
args.do_predict = False
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
|
||||
# Clean up bash script
|
||||
bash_script = (
|
||||
Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
|
||||
)
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
bash_script = bash_script.replace("--fp16 ", " ")
|
||||
model = distill_main(args)
|
||||
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||
bash_script = bash_script.replace("--fp16", "")
|
||||
epochs = 6
|
||||
testargs = (
|
||||
["distillation.py"]
|
||||
+ bash_script.split()
|
||||
+ [
|
||||
f"--output_dir={output_dir}",
|
||||
"--gpus=1",
|
||||
"--learning_rate=1e-3",
|
||||
f"--num_train_epochs={epochs}",
|
||||
"--warmup_steps=10",
|
||||
"--val_check_interval=1.0",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
args.do_predict = False
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
|
||||
|
||||
model = distill_main(args)
|
||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
|
||||
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
|
||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
|
2 binary files changed (contents not shown)
@@ -1,5 +1,4 @@
 import os
-import tempfile
 from pathlib import Path
 
 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader
 
 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
 
 
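The newly imported parameterized package is what lets the converted tests keep their multi-model coverage: pytest.mark.parametrize does not parametrize methods of a unittest.TestCase, so the class-based tests below use parameterized.expand instead. An illustrative sketch (the placeholder names are not taken from the diff):

import unittest

from parameterized import parameterized


class TestExample(unittest.TestCase):
    # each entry in the list is passed to the test method as tok_name
    @parameterized.expand(["tok-a", "tok-b"])  # placeholder names, not real checkpoints
    def test_name_is_string(self, tok_name):
        self.assertIsInstance(tok_name, str)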
@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
|
||||
PEGASUS_XSUM = "google/pegasus-xsum"
|
||||
|
||||
|
||||
@slow
|
||||
@pytest.mark.parametrize(
|
||||
"tok_name",
|
||||
[
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_seq2seq_dataset_truncation(tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
max_src_len = 4
|
||||
max_tgt_len = 8
|
||||
assert max_len_target > max_src_len # Will be truncated
|
||||
assert max_len_source > max_src_len # Will be truncated
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=max_src_len,
|
||||
max_target_length=max_tgt_len, # ignored
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
class TestAll(TestCasePlus):
|
||||
@parameterized.expand(
|
||||
[
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, dict)
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_src_len
|
||||
# show that targets are the same len
|
||||
assert batch["labels"].shape[1] == max_tgt_len
|
||||
if tok_name != MBART_TINY:
|
||||
continue
|
||||
# check language codes in correct place
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
|
||||
def test_legacy_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = LegacySeq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_len_source
|
||||
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||
# show that targets were truncated
|
||||
assert batch["labels"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
def test_pack_dataset():
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||
|
||||
tmp_dir = Path(make_test_data_dir())
|
||||
orig_examples = tmp_dir.joinpath("train.source").open().readlines()
|
||||
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
||||
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
||||
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
||||
new_paths = {x.name for x in save_dir.iterdir()}
|
||||
packed_examples = save_dir.joinpath("train.source").open().readlines()
|
||||
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
|
||||
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
|
||||
assert len(packed_examples) < len(orig_examples)
|
||||
assert len(packed_examples) == 1
|
||||
assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
|
||||
assert orig_paths == new_paths
|
||||
|
||||
|
||||
@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
|
||||
def test_dynamic_batch_size():
|
||||
if not FAIRSEQ_AVAILABLE:
|
||||
return
|
||||
ds, max_tokens, tokenizer = _get_dataset(max_len=64)
|
||||
required_batch_size_multiple = 64
|
||||
batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
|
||||
batch_sizes = [len(x) for x in batch_sampler]
|
||||
assert len(set(batch_sizes)) > 1 # it's not dynamic batch size if every batch is the same length
|
||||
assert sum(batch_sizes) == len(ds) # no dropped or added examples
|
||||
data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
|
||||
failures = []
|
||||
num_src_per_batch = []
|
||||
for batch in data_loader:
|
||||
src_shape = batch["input_ids"].shape
|
||||
bs = src_shape[0]
|
||||
assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
|
||||
num_src_tokens = np.product(batch["input_ids"].shape)
|
||||
num_src_per_batch.append(num_src_tokens)
|
||||
if num_src_tokens > (max_tokens * 1.1):
|
||||
failures.append(num_src_tokens)
|
||||
assert num_src_per_batch[0] == max(num_src_per_batch)
|
||||
if failures:
|
||||
raise AssertionError(f"too many tokens in {len(failures)} batches")
|
||||
|
||||
|
||||
def test_sortish_sampler_reduces_padding():
|
||||
ds, _, tokenizer = _get_dataset(max_len=512)
|
||||
bs = 2
|
||||
sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
|
||||
|
||||
naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
|
||||
sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)
|
||||
|
||||
pad = tokenizer.pad_token_id
|
||||
|
||||
def count_pad_tokens(data_loader, k="input_ids"):
|
||||
return [batch[k].eq(pad).sum().item() for batch in data_loader]
|
||||
|
||||
assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
|
||||
assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
|
||||
assert len(sortish_dl) == len(naive_dl)
|
||||
|
||||
|
||||
def _get_dataset(n_obs=1000, max_len=128):
|
||||
if os.getenv("USE_REAL_DATA", False):
|
||||
data_dir = "examples/seq2seq/wmt_en_ro"
|
||||
max_tokens = max_len * 2 * 64
|
||||
if not Path(data_dir).joinpath("train.len").exists():
|
||||
save_len_file(MARIAN_TINY, data_dir)
|
||||
else:
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
max_tokens = max_len * 4
|
||||
save_len_file(MARIAN_TINY, data_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
|
||||
ds = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=data_dir,
|
||||
type_path="train",
|
||||
max_source_length=max_len,
|
||||
max_target_length=max_len,
|
||||
n_obs=n_obs,
|
||||
)
|
||||
return ds, max_tokens, tokenizer
|
||||
|
||||
|
||||
def test_distributed_sortish_sampler_splits_indices_between_procs():
|
||||
ds, max_tokens, tokenizer = _get_dataset()
|
||||
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
|
||||
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
|
||||
assert ids1.intersection(ids2) == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tok_name",
|
||||
[
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_dataset_kwargs(tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
if tok_name == MBART_TINY:
|
||||
@slow
|
||||
def test_seq2seq_dataset_truncation(self, tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
max_src_len = 4
|
||||
max_tgt_len = 8
|
||||
assert max_len_target > max_src_len # Will be truncated
|
||||
assert max_len_source > max_src_len # Will be truncated
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=make_test_data_dir(),
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=4,
|
||||
max_target_length=8,
|
||||
src_lang="EN",
|
||||
tgt_lang="FR",
|
||||
max_source_length=max_src_len,
|
||||
max_target_length=max_tgt_len, # ignored
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "src_lang" in kwargs and "tgt_lang" in kwargs
|
||||
else:
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, dict)
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_src_len
|
||||
# show that targets are the same len
|
||||
assert batch["labels"].shape[1] == max_tgt_len
|
||||
if tok_name != MBART_TINY:
|
||||
continue
|
||||
# check language codes in correct place
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
break # No need to test every batch
|
||||
|
||||
@parameterized.expand([BART_TINY, BERT_BASE_CASED])
|
||||
def test_legacy_dataset_truncation(self, tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = LegacySeq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
|
||||
assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_len_source
|
||||
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||
# show that targets were truncated
|
||||
assert batch["labels"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
break # No need to test every batch
|
||||
|
||||
def test_pack_dataset(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
||||
|
||||
tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
|
||||
orig_examples = tmp_dir.joinpath("train.source").open().readlines()
|
||||
save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
|
||||
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
||||
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
||||
new_paths = {x.name for x in save_dir.iterdir()}
|
||||
packed_examples = save_dir.joinpath("train.source").open().readlines()
|
||||
# orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
|
||||
# desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
|
||||
assert len(packed_examples) < len(orig_examples)
|
||||
assert len(packed_examples) == 1
|
||||
assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
|
||||
assert orig_paths == new_paths
|
||||
|
||||
@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
|
||||
def test_dynamic_batch_size(self):
|
||||
if not FAIRSEQ_AVAILABLE:
|
||||
return
|
||||
ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
|
||||
required_batch_size_multiple = 64
|
||||
batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
|
||||
batch_sizes = [len(x) for x in batch_sampler]
|
||||
assert len(set(batch_sizes)) > 1 # it's not dynamic batch size if every batch is the same length
|
||||
assert sum(batch_sizes) == len(ds) # no dropped or added examples
|
||||
data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
|
||||
failures = []
|
||||
num_src_per_batch = []
|
||||
for batch in data_loader:
|
||||
src_shape = batch["input_ids"].shape
|
||||
bs = src_shape[0]
|
||||
assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
|
||||
num_src_tokens = np.product(batch["input_ids"].shape)
|
||||
num_src_per_batch.append(num_src_tokens)
|
||||
if num_src_tokens > (max_tokens * 1.1):
|
||||
failures.append(num_src_tokens)
|
||||
assert num_src_per_batch[0] == max(num_src_per_batch)
|
||||
if failures:
|
||||
raise AssertionError(f"too many tokens in {len(failures)} batches")
|
||||
|
||||
def test_sortish_sampler_reduces_padding(self):
|
||||
ds, _, tokenizer = self._get_dataset(max_len=512)
|
||||
bs = 2
|
||||
sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
|
||||
|
||||
naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
|
||||
sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)
|
||||
|
||||
pad = tokenizer.pad_token_id
|
||||
|
||||
def count_pad_tokens(data_loader, k="input_ids"):
|
||||
return [batch[k].eq(pad).sum().item() for batch in data_loader]
|
||||
|
||||
assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
|
||||
assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
|
||||
assert len(sortish_dl) == len(naive_dl)
|
||||
|
||||
def _get_dataset(self, n_obs=1000, max_len=128):
|
||||
if os.getenv("USE_REAL_DATA", False):
|
||||
data_dir = "examples/seq2seq/wmt_en_ro"
|
||||
max_tokens = max_len * 2 * 64
|
||||
if not Path(data_dir).joinpath("train.len").exists():
|
||||
save_len_file(MARIAN_TINY, data_dir)
|
||||
else:
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
max_tokens = max_len * 4
|
||||
save_len_file(MARIAN_TINY, data_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
|
||||
ds = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=data_dir,
|
||||
type_path="train",
|
||||
max_source_length=max_len,
|
||||
max_target_length=max_len,
|
||||
n_obs=n_obs,
|
||||
)
|
||||
return ds, max_tokens, tokenizer
|
||||
|
||||
def test_distributed_sortish_sampler_splits_indices_between_procs(self):
|
||||
ds, max_tokens, tokenizer = self._get_dataset()
|
||||
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
|
||||
ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
|
||||
assert ids1.intersection(ids2) == set()
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_dataset_kwargs(self, tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
if tok_name == MBART_TINY:
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
|
||||
type_path="train",
|
||||
max_source_length=4,
|
||||
max_target_length=8,
|
||||
src_lang="EN",
|
||||
tgt_lang="FR",
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "src_lang" in kwargs and "tgt_lang" in kwargs
|
||||
else:
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
|
||||
type_path="train",
|
||||
max_source_length=4,
|
||||
max_target_length=8,
|
||||
)
|
||||
kwargs = train_dataset.dataset_kwargs
|
||||
assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
|
||||
assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
|
||||
|
@@ -1,9 +1,8 @@
 import os
 import sys
-import tempfile
 from unittest.mock import patch
 
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
 
@ -15,72 +14,71 @@ set_seed(42)
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
def test_finetune_trainer():
|
||||
output_dir = run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
class TestFinetuneTrainer(TestCasePlus):
|
||||
def test_finetune_trainer(self):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow(self):
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow():
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
# Check metrics
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
# Check metrics
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
argv = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--data_dir {data_dir}
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--n_train 8
|
||||
--n_val 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_predict
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--learning_rate 3e-4
|
||||
--warmup_steps 8
|
||||
--evaluate_during_training
|
||||
--predict_with_generate
|
||||
--logging_steps 0
|
||||
--save_steps {str(eval_steps)}
|
||||
--eval_steps {str(eval_steps)}
|
||||
--sortish_sampler
|
||||
--label_smoothing 0.1
|
||||
--adafactor
|
||||
--task translation
|
||||
--tgt_lang ro_RO
|
||||
--src_lang en_XX
|
||||
""".split()
|
||||
# --eval_beams 2
|
||||
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = tempfile.mkdtemp(prefix="test_output")
|
||||
argv = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--data_dir {data_dir}
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--n_train 8
|
||||
--n_val 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_predict
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--learning_rate 3e-4
|
||||
--warmup_steps 8
|
||||
--evaluate_during_training
|
||||
--predict_with_generate
|
||||
--logging_steps 0
|
||||
--save_steps {str(eval_steps)}
|
||||
--eval_steps {str(eval_steps)}
|
||||
--sortish_sampler
|
||||
--label_smoothing 0.1
|
||||
--adafactor
|
||||
--task translation
|
||||
--tgt_lang ro_RO
|
||||
--src_lang en_XX
|
||||
""".split()
|
||||
# --eval_beams 2
|
||||
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
return output_dir
|
||||
return output_dir
|
||||
|
@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
 
@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
 
 
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }
 
 
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
 
 
-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir
 
 
-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
                 failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
 
-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())
 
-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
 
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
 
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))
@ -279,256 +268,253 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
return model
|
||||
|
||||
|
||||
def run_eval_tester(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(input_file_name, articles)
|
||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{input_file_name}
|
||||
{output_file_name}
|
||||
--score_path {score_path}
|
||||
--task {task}
|
||||
--num_beams 2
|
||||
--length_penalty 2.0
|
||||
""".split()
|
||||
class TestTheRest(TestCasePlus):
|
||||
def run_eval_tester(self, model):
|
||||
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(input_file_name, articles)
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{input_file_name}
|
||||
{output_file_name}
|
||||
--score_path {score_path}
|
||||
--task {task}
|
||||
--num_beams 2
|
||||
--length_penalty 2.0
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
# os.remove(Path(output_file_name))
|
||||
|
||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||
# extensive testing of functionality with multiple models as @slow separately
|
||||
def test_run_eval():
|
||||
run_eval_tester(T5_TINY)
|
||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||
# extensive testing of functionality with multiple models as @slow separately
|
||||
def test_run_eval(self):
|
||||
self.run_eval_tester(T5_TINY)
|
||||
|
||||
# any extra models should go into the list here - can be slow
|
||||
@parameterized.expand([BART_TINY, MBART_TINY])
|
||||
@slow
|
||||
def test_run_eval_slow(self, model):
|
||||
self.run_eval_tester(model)
|
||||
|
||||
# any extra models should go into the list here - can be slow
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
|
||||
def test_run_eval_slow(model):
|
||||
run_eval_tester(model)
|
||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||
@parameterized.expand([T5_TINY, MBART_TINY])
|
||||
@slow
|
||||
def test_run_eval_search(self, model):
|
||||
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
|
||||
text = {
|
||||
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
||||
"de": [
|
||||
"Maschinelles Lernen ist großartig, oder?",
|
||||
"Ich esse gerne Bananen",
|
||||
"Morgen ist wieder ein toller Tag!",
|
||||
],
|
||||
}
|
||||
|
||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
|
||||
def test_run_eval_search(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
tmp_dir = Path(self.get_auto_remove_tmp_dir())
|
||||
score_path = str(tmp_dir / "scores.json")
|
||||
reference_path = str(tmp_dir / "val.target")
|
||||
_dump_articles(input_file_name, text["en"])
|
||||
_dump_articles(reference_path, text["de"])
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{str(input_file_name)}
|
||||
{str(output_file_name)}
|
||||
--score_path {score_path}
|
||||
--reference_path {reference_path}
|
||||
--task {task}
|
||||
""".split()
|
||||
testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])
|
||||
|
||||
text = {
|
||||
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
||||
"de": [
|
||||
"Maschinelles Lernen ist großartig, oder?",
|
||||
"Ich esse gerne Bananen",
|
||||
"Morgen ist wieder ein toller Tag!",
|
||||
],
|
||||
}
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureStdout() as cs:
|
||||
run_search()
|
||||
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
|
||||
un_expected_strings = ["Info"]
|
||||
if "translation" in task:
|
||||
expected_strings.append("bleu")
|
||||
else:
|
||||
expected_strings.extend(ROUGE_KEYS)
|
||||
for w in expected_strings:
|
||||
assert w in cs.out
|
||||
for w in un_expected_strings:
|
||||
assert w not in cs.out
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
tmp_dir = Path(tempfile.mkdtemp())
|
||||
score_path = str(tmp_dir / "scores.json")
|
||||
reference_path = str(tmp_dir / "val.target")
|
||||
_dump_articles(input_file_name, text["en"])
|
||||
_dump_articles(reference_path, text["de"])
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{str(input_file_name)}
|
||||
{str(output_file_name)}
|
||||
--score_path {score_path}
|
||||
--reference_path {reference_path}
|
||||
--task {task}
|
||||
""".split()
|
||||
testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])
|
||||
@parameterized.expand(
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||
)
|
||||
def test_finetune(self, model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureStdout() as cs:
|
||||
run_search()
|
||||
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
|
||||
un_expected_strings = ["Info"]
|
||||
if "translation" in task:
|
||||
expected_strings.append("bleu")
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
module = main(args)
|
||||
|
||||
input_embeds = module.model.get_input_embeddings()
|
||||
assert not input_embeds.weight.requires_grad
|
||||
if model == T5_TINY:
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
elif model == FSMT_TINY:
|
||||
fsmt = module.model.model
|
||||
embed_pos = fsmt.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not fsmt.decoder.embed_tokens.weight.requires_grad
|
||||
# check that embeds are not the same
|
||||
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
|
||||
else:
|
||||
expected_strings.extend(ROUGE_KEYS)
|
||||
for w in expected_strings:
|
||||
assert w in cs.out
|
||||
for w in un_expected_strings:
|
||||
assert w not in cs.out
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not bart.shared.weight.requires_grad
|
||||
# check that embeds are the same
|
||||
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
example_batch = load_json(module.output_dir / "text_batch.json")
|
||||
assert isinstance(example_batch, dict)
|
||||
assert len(example_batch) >= 4
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||
def test_finetune_extra_model_args(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
module = main(args)
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
input_embeds = module.model.get_input_embeddings()
|
||||
assert not input_embeds.weight.requires_grad
|
||||
if model == T5_TINY:
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
elif model == FSMT_TINY:
|
||||
fsmt = module.model.model
|
||||
embed_pos = fsmt.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not fsmt.decoder.embed_tokens.weight.requires_grad
|
||||
# check that embeds are not the same
|
||||
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
|
||||
else:
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not bart.shared.weight.requires_grad
|
||||
# check that embeds are the same
|
||||
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
example_batch = load_json(module.output_dir / "text_batch.json")
|
||||
assert isinstance(example_batch, dict)
|
||||
assert len(example_batch) >= 4
|
||||
|
||||
|
||||
def test_finetune_extra_model_args():
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir()
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# test models whose config includes the extra_model_args
|
||||
model = BART_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
args_d1[p] = 0.5
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
for p in extra_model_params:
|
||||
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
|
||||
|
||||
# test models whose config doesn't include the extra_model_args
|
||||
model = T5_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_2_")
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
args = argparse.Namespace(**args_d2)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
# test models whose config includes the extra_model_args
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
args_d1[p] = 0.5
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
for p in extra_model_params:
|
||||
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
|
||||
|
||||
# test models whose config doesn't include the extra_model_args
|
||||
model = T5_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
args = argparse.Namespace(**args_d2)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
model = main(args)
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
|
||||
def test_finetune_lr_schedulers():
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
def test_finetune_lr_schedulers(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir()
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
model = BART_TINY
|
||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# emulate finetune.py
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = {"--help": True}
|
||||
# emulate finetune.py
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = {"--help": True}
|
||||
|
||||
# --help test
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStdout() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "--help is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = lightning_base.arg_to_scheduler_metavar
|
||||
assert expected in cs.out, "--help is expected to list the supported schedulers"
|
||||
# --help test
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStdout() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "--help is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = lightning_base.arg_to_scheduler_metavar
|
||||
assert expected in cs.out, "--help is expected to list the supported schedulers"
|
||||
|
||||
# --lr_scheduler=non_existing_scheduler test
|
||||
unsupported_param = "non_existing_scheduler"
|
||||
args = {f"--lr_scheduler={unsupported_param}"}
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStderr() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "invalid argument is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = f"invalid choice: '{unsupported_param}'"
|
||||
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||
# --lr_scheduler=non_existing_scheduler test
|
||||
unsupported_param = "non_existing_scheduler"
|
||||
args = {f"--lr_scheduler={unsupported_param}"}
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStderr() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "invalid argument is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = f"invalid choice: '{unsupported_param}'"
|
||||
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||
|
||||
# --lr_scheduler=existing_scheduler test
|
||||
supported_param = "cosine"
|
||||
args_d1 = args_d.copy()
|
||||
args_d1["lr_scheduler"] = supported_param
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
|
||||
# --lr_scheduler=existing_scheduler test
|
||||
supported_param = "cosine"
|
||||
args_d1 = args_d.copy()
|
||||
args_d1["lr_scheduler"] = supported_param
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert (
|
||||
getattr(model.hparams, "lr_scheduler") == supported_param
|
||||
), f"lr_scheduler={supported_param} shouldn't fail"
|
||||
|