[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)

Stas Bekman 2020-10-17 11:33:21 -07:00 committed by GitHub
parent dc552b9b70
commit 9f7b2b2432
6 changed files with 635 additions and 655 deletions
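
The diffs below all apply the same conversion: standalone pytest test functions move into `TestCasePlus` subclasses, and manual `tempfile.mkdtemp()` calls are replaced by `self.get_auto_remove_tmp_dir()`, which hands back a temp directory that `TestCasePlus` deletes when the test finishes. A minimal sketch of that pattern, assuming only the `TestCasePlus.get_auto_remove_tmp_dir()` helper already shown in the diffs (the class name and test body are illustrative, not part of the commit):

# Illustrative sketch only -- not one of the changed files.
import os

from transformers.testing_utils import TestCasePlus


class TestTmpDirPattern(TestCasePlus):
    def test_output_dir_is_auto_removed(self):
        # get_auto_remove_tmp_dir() returns a fresh directory and registers it
        # for deletion at teardown, replacing manual tempfile.mkdtemp() bookkeeping.
        output_dir = self.get_auto_remove_tmp_dir()
        path = os.path.join(output_dir, "dummy.txt")
        with open(path, "w") as f:
            f.write("hello")
        assert os.path.exists(path)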

View File

@@ -3,7 +3,6 @@
import argparse
import os
import sys
-import tempfile
from pathlib import Path
from unittest.mock import patch
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
from utils import load_json
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestAll(TestCasePlus):
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    @timeout_decorator.timeout(120)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_train_mbart_cc25_enro_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 4,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "facebook/mbart-large-cc25": MODEL_NAME,
            # Download is 120MB in previous test.
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        bash_script = bash_script.replace("--fp16 ", "")
        testargs = (
            ["finetune.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=3e-1",
                "--warmup_steps=0",
                "--val_check_interval=1.0",
                "--tokenizer_name=facebook/mbart-large-en-ro",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = main(args)

            # Check metrics
            metrics = load_json(model.metrics_save_path)
            first_step_stats = metrics["val"][0]
            last_step_stats = metrics["val"][-1]
            assert (
                len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
            )  # +1 accounts for val_sanity_check

            assert last_step_stats["val_avg_gen_time"] >= 0.01

            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

            # check lightning ckpt can be loaded and has a reasonable statedict
            contents = os.listdir(output_dir)
            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
            full_path = os.path.join(args.output_dir, ckpt_path)
            ckpt = torch.load(full_path, map_location="cpu")
            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
            assert expected_key in ckpt["state_dict"]
            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

            # TODO: turn on args.do_predict when PL bug fixed.
            if args.do_predict:
                contents = {os.path.basename(p) for p in contents}
                assert "test_generations.txt" in contents
                assert "test_results.txt" in contents
                # assert len(metrics["val"]) == desired_n_evals
                assert len(metrics["test"]) == 1

    @timeout_decorator.timeout(600)
    @slow
    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
    def test_opus_mt_distill_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = (
            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
        )
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = distill_main(args)

            # Check metrics
            metrics = load_json(model.metrics_save_path)
            first_step_stats = metrics["val"][0]
            last_step_stats = metrics["val"][-1]
            assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check

            assert last_step_stats["val_avg_gen_time"] >= 0.01
            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

            # check lightning ckpt can be loaded and has a reasonable statedict
            contents = os.listdir(output_dir)
            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
            full_path = os.path.join(args.output_dir, ckpt_path)
            ckpt = torch.load(full_path, map_location="cpu")
            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
            assert expected_key in ckpt["state_dict"]
            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

            # TODO: turn on args.do_predict when PL bug fixed.
            if args.do_predict:
                contents = {os.path.basename(p) for p in contents}
                assert "test_generations.txt" in contents
                assert "test_results.txt" in contents
                # assert len(metrics["val"]) == desired_n_evals
                assert len(metrics["test"]) == 1

View File

@@ -1,5 +1,4 @@
import os
-import tempfile
from pathlib import Path

import numpy as np
@@ -7,11 +6,12 @@ import pytest
from torch.utils.data import DataLoader

from pack_dataset import pack_data_dir
+from parameterized import parameterized
from save_len_file import save_len_file
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"


class TestAll(TestCasePlus):
    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    @slow
    def test_seq2seq_dataset_truncation(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        max_src_len = 4
        max_tgt_len = 8
        assert max_len_target > max_src_len  # Will be truncated
        assert max_len_source > max_src_len  # Will be truncated
        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,  # ignored
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert isinstance(batch, dict)
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_src_len
            # show that targets are the same len
            assert batch["labels"].shape[1] == max_tgt_len
            if tok_name != MBART_TINY:
                continue
            # check language codes in correct place
            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
            break  # No need to test every batch

    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
    def test_legacy_dataset_truncation(self, tok):
        tokenizer = AutoTokenizer.from_pretrained(tok)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        trunc_target = 4
        train_dataset = LegacySeq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=20,
            max_target_length=trunc_target,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_len_source
            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
            # show that targets were truncated
            assert batch["labels"].shape[1] == trunc_target  # Truncated
            assert max_len_target > trunc_target  # Truncated
            break  # No need to test every batch

    def test_pack_dataset(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        orig_examples = tmp_dir.joinpath("train.source").open().readlines()
        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
        orig_paths = {x.name for x in tmp_dir.iterdir()}
        new_paths = {x.name for x in save_dir.iterdir()}
        packed_examples = save_dir.joinpath("train.source").open().readlines()
        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
        assert len(packed_examples) < len(orig_examples)
        assert len(packed_examples) == 1
        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
        assert orig_paths == new_paths

    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
    def test_dynamic_batch_size(self):
        if not FAIRSEQ_AVAILABLE:
            return
        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
        required_batch_size_multiple = 64
        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
        batch_sizes = [len(x) for x in batch_sampler]
        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
        failures = []
        num_src_per_batch = []
        for batch in data_loader:
            src_shape = batch["input_ids"].shape
            bs = src_shape[0]
            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
            num_src_tokens = np.product(batch["input_ids"].shape)
            num_src_per_batch.append(num_src_tokens)
            if num_src_tokens > (max_tokens * 1.1):
                failures.append(num_src_tokens)
        assert num_src_per_batch[0] == max(num_src_per_batch)
        if failures:
            raise AssertionError(f"too many tokens in {len(failures)} batches")

    def test_sortish_sampler_reduces_padding(self):
        ds, _, tokenizer = self._get_dataset(max_len=512)
        bs = 2
        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)

        pad = tokenizer.pad_token_id

        def count_pad_tokens(data_loader, k="input_ids"):
            return [batch[k].eq(pad).sum().item() for batch in data_loader]

        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
        assert len(sortish_dl) == len(naive_dl)

    def _get_dataset(self, n_obs=1000, max_len=128):
        if os.getenv("USE_REAL_DATA", False):
            data_dir = "examples/seq2seq/wmt_en_ro"
            max_tokens = max_len * 2 * 64
            if not Path(data_dir).joinpath("train.len").exists():
                save_len_file(MARIAN_TINY, data_dir)
        else:
            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
            max_tokens = max_len * 4
            save_len_file(MARIAN_TINY, data_dir)

        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
        ds = Seq2SeqDataset(
            tokenizer,
            data_dir=data_dir,
            type_path="train",
            max_source_length=max_len,
            max_target_length=max_len,
            n_obs=n_obs,
        )
        return ds, max_tokens, tokenizer

    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
        ds, max_tokens, tokenizer = self._get_dataset()
        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
        assert ids1.intersection(ids2) == set()

    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    def test_dataset_kwargs(self, tok_name):
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        if tok_name == MBART_TINY:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
                src_lang="EN",
                tgt_lang="FR",
            )
            kwargs = train_dataset.dataset_kwargs
            assert "src_lang" in kwargs and "tgt_lang" in kwargs
        else:
            train_dataset = Seq2SeqDataset(
                tokenizer,
                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                type_path="train",
                max_source_length=4,
                max_target_length=8,
            )
            kwargs = train_dataset.dataset_kwargs
            assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
            assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
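
The `@pytest.mark.parametrize` decorators above are swapped for `parameterized.expand` because pytest's parametrization does not work on methods of a `unittest.TestCase` subclass such as `TestCasePlus`; `parameterized.expand` instead generates one test method per entry and passes the entry to the method as an extra argument. A small sketch of that behaviour (the example class and data are illustrative, not from the commit):

# Illustrative sketch only -- not part of the diff.
import unittest

from parameterized import parameterized


class TestRoundTrip(unittest.TestCase):
    # Generates one test method per listed string, each receiving it as `word`.
    @parameterized.expand(["foo", "bar"])
    def test_round_trip(self, word):
        assert word.upper().lower() == word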

View File

@@ -1,9 +1,8 @@
import os
import sys
-import tempfile
from unittest.mock import patch

-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
@@ -15,72 +14,71 @@ set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


class TestFinetuneTrainer(TestCasePlus):
    def test_finetune_trainer(self):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        argv = f"""
            --model_name_or_path {model_name}
            --data_dir {data_dir}
            --output_dir {output_dir}
            --overwrite_output_dir
            --n_train 8
            --n_val 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-4
            --warmup_steps 8
            --evaluate_during_training
            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --sortish_sampler
            --label_smoothing 0.1
            --adafactor
            --task translation
            --tgt_lang ro_RO
            --src_lang en_XX
        """.split()
        # --eval_beams 2

        testargs = ["finetune_trainer.py"] + argv
        with patch.object(sys, "argv", testargs):
            main()

        return output_dir
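
The `run_trainer` helper above, like the `run_eval` tests in the next file, drives the example script by patching `sys.argv` and calling its `main()` in-process rather than spawning a subprocess. A stripped-down sketch of that technique (`fake_main` is a stand-in, not the real script's entry point):

# Illustrative sketch only -- `fake_main` stands in for a script's main().
import sys
from unittest.mock import patch


def fake_main():
    # A real script's main() would parse sys.argv[1:] with argparse.
    return sys.argv[1:]


testargs = ["finetune_trainer.py", "--output_dir", "/tmp/out", "--do_train"]
with patch.object(sys, "argv", testargs):
    assert fake_main() == ["--output_dir", "/tmp/out", "--do_train"]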

View File

@@ -3,7 +3,6 @@ import logging
import os
import sys
import tempfile
-import unittest
from pathlib import Path
from unittest.mock import patch
@@ -15,11 +14,12 @@ import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main
from finetune import SummarizationModule, main
+from parameterized import parameterized
from run_eval import generate_summaries_or_translations, run_generate
from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
    "student_decoder_layers": 1,
    "val_check_interval": 1.0,
    "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
    "no_teacher": False,
    "fp16_opt_level": "O1",
    "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
    "student_encoder_layers": 1,
    "freeze_encoder": False,
    "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
}
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
    for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
    return tmp_dir


-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
    @classmethod
    def setUpClass(cls):
        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
                failures.append(m)
        assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"

-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
    def test_distill_no_teacher(self):
        updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
        self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
        self.assertEqual(1, len(ckpts))
        transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
        self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
        generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
        self.assertTrue(Path(out_path).exists())

-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
        convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
        assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
        )
        default_updates.update(updates)
        args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()

        args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
        model = distill_main(argparse.Namespace(**args_d))
@@ -279,256 +268,253 @@ class TestSummarizationDistiller(unittest.TestCase):
        return model


class TestTheRest(TestCasePlus):
    def run_eval_tester(self, model):
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()
        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        _dump_articles(input_file_name, articles)

        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = f"""
            run_eval_search.py
            {model}
            {input_file_name}
            {output_file_name}
            --score_path {score_path}
            --task {task}
            --num_beams 2
            --length_penalty 2.0
        """.split()

        with patch.object(sys, "argv", testargs):
            run_generate()
            assert Path(output_file_name).exists()
            # os.remove(Path(output_file_name))

    # test one model to quickly (no-@slow) catch simple problems and do an
    # extensive testing of functionality with multiple models as @slow separately
    def test_run_eval(self):
        self.run_eval_tester(T5_TINY)

    # any extra models should go into the list here - can be slow
    @parameterized.expand([BART_TINY, MBART_TINY])
    @slow
    def test_run_eval_slow(self, model):
        self.run_eval_tester(model)

    # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
    @parameterized.expand([T5_TINY, MBART_TINY])
    @slow
    def test_run_eval_search(self, model):
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()

        text = {
            "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
            "de": [
                "Maschinelles Lernen ist großartig, oder?",
                "Ich esse gerne Bananen",
                "Morgen ist wieder ein toller Tag!",
            ],
        }

        tmp_dir = Path(self.get_auto_remove_tmp_dir())
        score_path = str(tmp_dir / "scores.json")
        reference_path = str(tmp_dir / "val.target")
        _dump_articles(input_file_name, text["en"])
        _dump_articles(reference_path, text["de"])
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = f"""
            run_eval_search.py
            {model}
            {str(input_file_name)}
            {str(output_file_name)}
            --score_path {score_path}
            --reference_path {reference_path}
            --task {task}
        """.split()
        testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])

        with patch.object(sys, "argv", testargs):
            with CaptureStdout() as cs:
                run_search()
            expected_strings = [" num_beams | length_penalty", model, "Best score args"]
            un_expected_strings = ["Info"]
            if "translation" in task:
                expected_strings.append("bleu")
            else:
                expected_strings.extend(ROUGE_KEYS)
            for w in expected_strings:
                assert w in cs.out
            for w in un_expected_strings:
                assert w not in cs.out
            assert Path(output_file_name).exists()
            os.remove(Path(output_file_name))

    @parameterized.expand(
        [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
    )
    def test_finetune(self, model):
        args_d: dict = CHEAP_ARGS.copy()
        task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
        args_d["label_smoothing"] = 0.1 if task == "translation" else 0

        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        output_dir = self.get_auto_remove_tmp_dir()
        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            output_dir=output_dir,
            do_predict=True,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )
        assert "n_train" in args_d
        args = argparse.Namespace(**args_d)
        module = main(args)

        input_embeds = module.model.get_input_embeddings()
        assert not input_embeds.weight.requires_grad
        if model == T5_TINY:
            lm_head = module.model.lm_head
            assert not lm_head.weight.requires_grad
            assert (lm_head.weight == input_embeds.weight).all().item()
        elif model == FSMT_TINY:
            fsmt = module.model.model
            embed_pos = fsmt.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not fsmt.decoder.embed_tokens.weight.requires_grad
            # check that embeds are not the same
            assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
        else:
            bart = module.model.model
            embed_pos = bart.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not bart.shared.weight.requires_grad
            # check that embeds are the same
            assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
            assert bart.decoder.embed_tokens == bart.shared

        example_batch = load_json(module.output_dir / "text_batch.json")
        assert isinstance(example_batch, dict)
        assert len(example_batch) >= 4

    def test_finetune_extra_model_args(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        args_d.update(
            data_dir=tmp_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # test models whose config includes the extra_model_args
        model = BART_TINY
        output_dir = self.get_auto_remove_tmp_dir()
        args_d1 = args_d.copy()
        args_d1.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            args_d1[p] = 0.5
        args = argparse.Namespace(**args_d1)
        model = main(args)
        for p in extra_model_params:
            assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

        # test models whose config doesn't include the extra_model_args
        model = T5_TINY
        output_dir = self.get_auto_remove_tmp_dir()
        args_d2 = args_d.copy()
        args_d2.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        unsupported_param = "encoder_layerdrop"
        args_d2[unsupported_param] = 0.5
        args = argparse.Namespace(**args_d2)
        with pytest.raises(Exception) as excinfo:
            model = main(args)
        assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"

    def test_finetune_lr_schedulers(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        model = BART_TINY
        output_dir = self.get_auto_remove_tmp_dir()

        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            output_dir=output_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # emulate finetune.py
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = {"--help": True}

        # --help test
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStdout() as cs:
                args = parser.parse_args(args)
            assert False, "--help is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = lightning_base.arg_to_scheduler_metavar
        assert expected in cs.out, "--help is expected to list the supported schedulers"

        # --lr_scheduler=non_existing_scheduler test
        unsupported_param = "non_existing_scheduler"
        args = {f"--lr_scheduler={unsupported_param}"}
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStderr() as cs:
                args = parser.parse_args(args)
            assert False, "invalid argument is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = f"invalid choice: '{unsupported_param}'"
        assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

        # --lr_scheduler=existing_scheduler test
        supported_param = "cosine"
        args_d1 = args_d.copy()
        args_d1["lr_scheduler"] = supported_param
        args = argparse.Namespace(**args_d1)
        model = main(args)
        assert (
            getattr(model.hparams, "lr_scheduler") == supported_param
        ), f"lr_scheduler={supported_param} shouldn't fail"