Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)

Commit 9f7b2b2432 (parent dc552b9b70)
[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)
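The pattern applied throughout this commit: module-level pytest functions become methods on a unittest-style test case, and hand-rolled tempfile.mkdtemp() directories are replaced by temp dirs that remove themselves when the test finishes (via TestCasePlus.get_auto_remove_tmp_dir() in the hunks below). A minimal sketch of the same idea with plain unittest; the helper defined here is illustrative only and assumes nothing about the TestCasePlus internals.

import os
import tempfile
import unittest


class ExampleTest(unittest.TestCase):
    def get_auto_remove_tmp_dir(self):
        # Create a temp dir and register its removal, so it is deleted even if the test fails.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        return tmp.name

    def test_writes_to_tmp_dir(self):
        output_dir = self.get_auto_remove_tmp_dir()
        with open(os.path.join(output_dir, "results.txt"), "w") as f:
            f.write("ok")
        self.assertTrue(os.path.exists(os.path.join(output_dir, "results.txt")))


if __name__ == "__main__":
    unittest.main()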
@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
 
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
 
@@ -24,163 +23,164 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
 
+class TestAll(TestCasePlus):
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
+        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
+        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
+        MarianMTModel.from_pretrained(MARIAN_MODEL)
+
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",
+            "$MAX_LEN": 128,
+            "$BS": 4,
+            "$GAS": 1,
+            "$ENRO_DIR": data_dir,
+            "facebook/mbart-large-cc25": MODEL_NAME,
+            # Download is 120MB in previous test.
+            "val_check_interval=0.25": "val_check_interval=1.0",
+        }
+
+        # Clean up bash script
+        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
+        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+        for k, v in env_vars_to_replace.items():
+            bash_script = bash_script.replace(k, str(v))
+        output_dir = self.get_auto_remove_tmp_dir()
+
+        bash_script = bash_script.replace("--fp16 ", "")
+        testargs = (
+            ["finetune.py"]
+            + bash_script.split()
+            + [
+                f"--output_dir={output_dir}",
+                "--gpus=1",
+                "--learning_rate=3e-1",
+                "--warmup_steps=0",
+                "--val_check_interval=1.0",
+                "--tokenizer_name=facebook/mbart-large-en-ro",
+            ]
+        )
+        with patch.object(sys, "argv", testargs):
+            parser = argparse.ArgumentParser()
+            parser = pl.Trainer.add_argparse_args(parser)
+            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
+            args = parser.parse_args()
+            args.do_predict = False
+            # assert args.gpus == gpus THIS BREAKS for multigpu
+            model = main(args)
+
+            # Check metrics
+            metrics = load_json(model.metrics_save_path)
+            first_step_stats = metrics["val"][0]
+            last_step_stats = metrics["val"][-1]
+            assert (
+                len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
+            )  # +1 accounts for val_sanity_check
+
+            assert last_step_stats["val_avg_gen_time"] >= 0.01
+
+            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+            # check lightning ckpt can be loaded and has a reasonable statedict
+            contents = os.listdir(output_dir)
+            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+            full_path = os.path.join(args.output_dir, ckpt_path)
+            ckpt = torch.load(full_path, map_location="cpu")
+            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+            assert expected_key in ckpt["state_dict"]
+            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+            # TODO: turn on args.do_predict when PL bug fixed.
+            if args.do_predict:
+                contents = {os.path.basename(p) for p in contents}
+                assert "test_generations.txt" in contents
+                assert "test_results.txt" in contents
+                # assert len(metrics["val"]) == desired_n_evals
+                assert len(metrics["test"]) == 1
+
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        env_vars_to_replace = {
+            "--fp16_opt_level=O1": "",
+            "$MAX_LEN": 128,
+            "$BS": 16,
+            "$GAS": 1,
+            "$ENRO_DIR": data_dir,
+            "$m": "sshleifer/student_marian_en_ro_6_1",
+            "val_check_interval=0.25": "val_check_interval=1.0",
+        }
+
+        # Clean up bash script
+        bash_script = (
+            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+        )
+        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
+        bash_script = bash_script.replace("--fp16 ", " ")
+
+        for k, v in env_vars_to_replace.items():
+            bash_script = bash_script.replace(k, str(v))
+        output_dir = self.get_auto_remove_tmp_dir()
+        bash_script = bash_script.replace("--fp16", "")
+        epochs = 6
+        testargs = (
+            ["distillation.py"]
+            + bash_script.split()
+            + [
+                f"--output_dir={output_dir}",
+                "--gpus=1",
+                "--learning_rate=1e-3",
+                f"--num_train_epochs={epochs}",
+                "--warmup_steps=10",
+                "--val_check_interval=1.0",
+            ]
+        )
+        with patch.object(sys, "argv", testargs):
+            parser = argparse.ArgumentParser()
+            parser = pl.Trainer.add_argparse_args(parser)
+            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
+            args = parser.parse_args()
+            args.do_predict = False
+            # assert args.gpus == gpus THIS BREAKS for multigpu
+
+            model = distill_main(args)
+
+            # Check metrics
+            metrics = load_json(model.metrics_save_path)
+            first_step_stats = metrics["val"][0]
+            last_step_stats = metrics["val"][-1]
+            assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # +1 accounts for val_sanity_check
+
+            assert last_step_stats["val_avg_gen_time"] >= 0.01
+            assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
+            assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
+            assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+            # check lightning ckpt can be loaded and has a reasonable statedict
+            contents = os.listdir(output_dir)
+            ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
+            full_path = os.path.join(args.output_dir, ckpt_path)
+            ckpt = torch.load(full_path, map_location="cpu")
+            expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
+            assert expected_key in ckpt["state_dict"]
+            assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
+
+            # TODO: turn on args.do_predict when PL bug fixed.
+            if args.do_predict:
+                contents = {os.path.basename(p) for p in contents}
+                assert "test_generations.txt" in contents
+                assert "test_results.txt" in contents
+                # assert len(metrics["val"]) == desired_n_evals
+                assert len(metrics["test"]) == 1
Binary file not shown.
@@ -1,5 +1,4 @@
 import os
-import tempfile
 from pathlib import Path
 
 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader
 
 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
 
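A hedged aside on the parameterized import added above: pytest.mark.parametrize does not apply to unittest.TestCase methods, which is why the converted tests below switch to parameterized.expand, generating one test method per entry. A minimal sketch with placeholder model names (not the tiny checkpoints used in these tests):

import unittest

from parameterized import parameterized


class ExampleParamTest(unittest.TestCase):
    @parameterized.expand(
        [
            "model-a-tiny",
            "model-b-tiny",
        ]
    )
    def test_has_name(self, tok_name):
        # Each entry above becomes its own test method, e.g. test_has_name_0_model_a_tiny.
        self.assertTrue(len(tok_name) > 0)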
@@ -19,202 +19,198 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"
 
 
+class TestAll(TestCasePlus):
+    @parameterized.expand(
+        [
+            MBART_TINY,
+            MARIAN_TINY,
+            T5_TINY,
+            BART_TINY,
+            PEGASUS_XSUM,
+        ],
+    )
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):
+        tokenizer = AutoTokenizer.from_pretrained(tok_name)
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
+        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
+        max_src_len = 4
+        max_tgt_len = 8
+        assert max_len_target > max_src_len  # Will be truncated
+        assert max_len_source > max_src_len  # Will be truncated
+        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=tmp_dir,
+            type_path="train",
+            max_source_length=max_src_len,
+            max_target_length=max_tgt_len,  # ignored
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            assert isinstance(batch, dict)
+            assert batch["attention_mask"].shape == batch["input_ids"].shape
+            # show that articles were trimmed.
+            assert batch["input_ids"].shape[1] == max_src_len
+            # show that targets are the same len
+            assert batch["labels"].shape[1] == max_tgt_len
+            if tok_name != MBART_TINY:
+                continue
+            # check language codes in correct place
+            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
+            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
+            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
+            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
+            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
+
+            break  # No need to test every batch
+
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):
+        tokenizer = AutoTokenizer.from_pretrained(tok)
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
+        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
+        trunc_target = 4
+        train_dataset = LegacySeq2SeqDataset(
+            tokenizer,
+            data_dir=tmp_dir,
+            type_path="train",
+            max_source_length=20,
+            max_target_length=trunc_target,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            assert batch["attention_mask"].shape == batch["input_ids"].shape
+            # show that articles were trimmed.
+            assert batch["input_ids"].shape[1] == max_len_source
+            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
+            # show that targets were truncated
+            assert batch["labels"].shape[1] == trunc_target  # Truncated
+            assert max_len_target > trunc_target  # Truncated
+            break  # No need to test every batch
+
+    def test_pack_dataset(self):
+        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
+        orig_examples = tmp_dir.joinpath("train.source").open().readlines()
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
+        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
+        orig_paths = {x.name for x in tmp_dir.iterdir()}
+        new_paths = {x.name for x in save_dir.iterdir()}
+        packed_examples = save_dir.joinpath("train.source").open().readlines()
+        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
+        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
+        assert len(packed_examples) < len(orig_examples)
+        assert len(packed_examples) == 1
+        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
+        assert orig_paths == new_paths
+
+    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
+    def test_dynamic_batch_size(self):
+        if not FAIRSEQ_AVAILABLE:
+            return
+        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
+        required_batch_size_multiple = 64
+        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
+        batch_sizes = [len(x) for x in batch_sampler]
+        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
+        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
+        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
+        failures = []
+        num_src_per_batch = []
+        for batch in data_loader:
+            src_shape = batch["input_ids"].shape
+            bs = src_shape[0]
+            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
+            num_src_tokens = np.product(batch["input_ids"].shape)
+            num_src_per_batch.append(num_src_tokens)
+            if num_src_tokens > (max_tokens * 1.1):
+                failures.append(num_src_tokens)
+        assert num_src_per_batch[0] == max(num_src_per_batch)
+        if failures:
+            raise AssertionError(f"too many tokens in {len(failures)} batches")
+
+    def test_sortish_sampler_reduces_padding(self):
+        ds, _, tokenizer = self._get_dataset(max_len=512)
+        bs = 2
+        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
+
+        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
+        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)
+
+        pad = tokenizer.pad_token_id
+
+        def count_pad_tokens(data_loader, k="input_ids"):
+            return [batch[k].eq(pad).sum().item() for batch in data_loader]
+
+        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
+        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
+        assert len(sortish_dl) == len(naive_dl)
+
+    def _get_dataset(self, n_obs=1000, max_len=128):
+        if os.getenv("USE_REAL_DATA", False):
+            data_dir = "examples/seq2seq/wmt_en_ro"
+            max_tokens = max_len * 2 * 64
+            if not Path(data_dir).joinpath("train.len").exists():
+                save_len_file(MARIAN_TINY, data_dir)
+        else:
+            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+            max_tokens = max_len * 4
+            save_len_file(MARIAN_TINY, data_dir)
+
+        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
+        ds = Seq2SeqDataset(
+            tokenizer,
+            data_dir=data_dir,
+            type_path="train",
+            max_source_length=max_len,
+            max_target_length=max_len,
+            n_obs=n_obs,
+        )
+        return ds, max_tokens, tokenizer
+
+    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
+        ds, max_tokens, tokenizer = self._get_dataset()
+        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
+        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
+        assert ids1.intersection(ids2) == set()
+
+    @parameterized.expand(
+        [
+            MBART_TINY,
+            MARIAN_TINY,
+            T5_TINY,
+            BART_TINY,
+            PEGASUS_XSUM,
+        ],
+    )
+    def test_dataset_kwargs(self, tok_name):
+        tokenizer = AutoTokenizer.from_pretrained(tok_name)
+        if tok_name == MBART_TINY:
+            train_dataset = Seq2SeqDataset(
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
+                src_lang="EN",
+                tgt_lang="FR",
+            )
+            kwargs = train_dataset.dataset_kwargs
+            assert "src_lang" in kwargs and "tgt_lang" in kwargs
+        else:
+            train_dataset = Seq2SeqDataset(
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
+            )
+            kwargs = train_dataset.dataset_kwargs
+            assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+            assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
@@ -1,9 +1,8 @@
 import os
 import sys
-import tempfile
 from unittest.mock import patch
 
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
 
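For context on the sys.argv patching used by the converted trainer tests below, here is a small self-contained sketch; the cli_main function and its single flag are hypothetical stand-ins for the real script entry point and its arguments.

import argparse
import sys
from unittest.mock import patch


def cli_main():
    # Stand-in for a script-level main() that parses its own command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", required=True)
    args = parser.parse_args()
    return args.output_dir


def run_with_fake_argv(output_dir):
    testargs = ["finetune_trainer.py", f"--output_dir={output_dir}"]
    with patch.object(sys, "argv", testargs):
        # argparse reads sys.argv[1:], so the patched list drives the parser.
        return cli_main()


if __name__ == "__main__":
    assert run_with_fake_argv("/tmp/out") == "/tmp/out"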
@@ -15,72 +14,71 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 
 
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
+        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+        first_step_stats = eval_metrics[0]
+        assert "eval_bleu" in first_step_stats
+
+    @slow
+    def test_finetune_trainer_slow(self):
+        # There is a missing call to __init__process_group somewhere
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+
+        # Check metrics
+        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+        first_step_stats = eval_metrics[0]
+        last_step_stats = eval_metrics[-1]
+
+        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
+        assert isinstance(last_step_stats["eval_bleu"], float)
+
+        # test if do_predict saves generations and metrics
+        contents = os.listdir(output_dir)
+        contents = {os.path.basename(p) for p in contents}
+        assert "test_generations.txt" in contents
+        assert "test_results.json" in contents
+
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        output_dir = self.get_auto_remove_tmp_dir()
+        argv = f"""
+            --model_name_or_path {model_name}
+            --data_dir {data_dir}
+            --output_dir {output_dir}
+            --overwrite_output_dir
+            --n_train 8
+            --n_val 8
+            --max_source_length {max_len}
+            --max_target_length {max_len}
+            --val_max_target_length {max_len}
+            --do_train
+            --do_eval
+            --do_predict
+            --num_train_epochs {str(num_train_epochs)}
+            --per_device_train_batch_size 4
+            --per_device_eval_batch_size 4
+            --learning_rate 3e-4
+            --warmup_steps 8
+            --evaluate_during_training
+            --predict_with_generate
+            --logging_steps 0
+            --save_steps {str(eval_steps)}
+            --eval_steps {str(eval_steps)}
+            --sortish_sampler
+            --label_smoothing 0.1
+            --adafactor
+            --task translation
+            --tgt_lang ro_RO
+            --src_lang en_XX
+        """.split()
+        # --eval_beams 2
+
+        testargs = ["finetune_trainer.py"] + argv
+        with patch.object(sys, "argv", testargs):
+            main()
+        return output_dir
@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
 
@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
 
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }
 
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
 
 
-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir
 
 
-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
                 failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
 
-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
-
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())
 
-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
 
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))
@ -279,256 +268,253 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def run_eval_tester(model):
|
class TestTheRest(TestCasePlus):
|
||||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
def run_eval_tester(self, model):
|
||||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||||
assert not output_file_name.exists()
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
assert not output_file_name.exists()
|
||||||
_dump_articles(input_file_name, articles)
|
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
_dump_articles(input_file_name, articles)
|
||||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
|
||||||
testargs = f"""
|
|
||||||
run_eval_search.py
|
|
||||||
{model}
|
|
||||||
{input_file_name}
|
|
||||||
{output_file_name}
|
|
||||||
--score_path {score_path}
|
|
||||||
--task {task}
|
|
||||||
--num_beams 2
|
|
||||||
--length_penalty 2.0
|
|
||||||
""".split()
|
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
|
||||||
run_generate()
|
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||||
assert Path(output_file_name).exists()
|
testargs = f"""
|
||||||
os.remove(Path(output_file_name))
|
run_eval_search.py
|
||||||
|
{model}
|
||||||
|
{input_file_name}
|
||||||
|
{output_file_name}
|
||||||
|
--score_path {score_path}
|
||||||
|
--task {task}
|
||||||
|
--num_beams 2
|
||||||
|
--length_penalty 2.0
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_generate()
|
||||||
|
assert Path(output_file_name).exists()
|
||||||
|
# os.remove(Path(output_file_name))
|
||||||
|
|
||||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||||
# extensive testing of functionality with multiple models as @slow separately
|
# extensive testing of functionality with multiple models as @slow separately
|
||||||
def test_run_eval():
|
def test_run_eval(self):
|
||||||
run_eval_tester(T5_TINY)
|
self.run_eval_tester(T5_TINY)
|
||||||
|
|
||||||
|
# any extra models should go into the list here - can be slow
|
||||||
|
@parameterized.expand([BART_TINY, MBART_TINY])
|
||||||
|
@slow
|
||||||
|
def test_run_eval_slow(self, model):
|
||||||
|
self.run_eval_tester(model)
|
||||||
|
|
||||||
# any extra models should go into the list here - can be slow
|
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||||
@slow
|
@parameterized.expand([T5_TINY, MBART_TINY])
|
||||||
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
|
@slow
|
||||||
def test_run_eval_slow(model):
|
def test_run_eval_search(self, model):
|
||||||
run_eval_tester(model)
|
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||||
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||||
|
assert not output_file_name.exists()
|
||||||
|
|
||||||
|
text = {
|
||||||
|
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
||||||
|
"de": [
|
||||||
|
"Maschinelles Lernen ist großartig, oder?",
|
||||||
|
"Ich esse gerne Bananen",
|
||||||
|
"Morgen ist wieder ein toller Tag!",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
tmp_dir = Path(self.get_auto_remove_tmp_dir())
|
||||||
@slow
|
score_path = str(tmp_dir / "scores.json")
|
||||||
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
|
reference_path = str(tmp_dir / "val.target")
|
||||||
def test_run_eval_search(model):
|
_dump_articles(input_file_name, text["en"])
|
||||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
_dump_articles(reference_path, text["de"])
|
||||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||||
assert not output_file_name.exists()
|
testargs = f"""
|
||||||
|
run_eval_search.py
|
||||||
|
{model}
|
||||||
|
{str(input_file_name)}
|
||||||
|
{str(output_file_name)}
|
||||||
|
--score_path {score_path}
|
||||||
|
--reference_path {reference_path}
|
||||||
|
--task {task}
|
||||||
|
""".split()
|
||||||
|
testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])
|
||||||
|
|
||||||
text = {
|
with patch.object(sys, "argv", testargs):
|
||||||
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
|
with CaptureStdout() as cs:
|
||||||
"de": [
|
run_search()
|
||||||
"Maschinelles Lernen ist großartig, oder?",
|
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
|
||||||
"Ich esse gerne Bananen",
|
un_expected_strings = ["Info"]
|
||||||
"Morgen ist wieder ein toller Tag!",
|
if "translation" in task:
|
||||||
],
|
expected_strings.append("bleu")
|
||||||
}
|
else:
|
||||||
|
expected_strings.extend(ROUGE_KEYS)
|
||||||
|
for w in expected_strings:
|
||||||
|
assert w in cs.out
|
||||||
|
for w in un_expected_strings:
|
||||||
|
assert w not in cs.out
|
||||||
|
assert Path(output_file_name).exists()
|
||||||
|
os.remove(Path(output_file_name))
|
||||||
|
|
||||||
tmp_dir = Path(tempfile.mkdtemp())
|
@parameterized.expand(
|
||||||
score_path = str(tmp_dir / "scores.json")
|
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||||
reference_path = str(tmp_dir / "val.target")
|
)
|
||||||
_dump_articles(input_file_name, text["en"])
|
def test_finetune(self, model):
|
||||||
_dump_articles(reference_path, text["de"])
|
args_d: dict = CHEAP_ARGS.copy()
|
||||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||||
testargs = f"""
|
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||||
run_eval_search.py
|
|
||||||
{model}
|
|
||||||
{str(input_file_name)}
|
|
||||||
{str(output_file_name)}
|
|
||||||
--score_path {score_path}
|
|
||||||
--reference_path {reference_path}
|
|
||||||
--task {task}
|
|
||||||
""".split()
|
|
||||||
testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])
-@pytest.mark.parametrize(
-    "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
-)
-def test_finetune(model):
+    @parameterized.expand(
+        [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
+    )
+    def test_finetune(self, model):
        args_d: dict = CHEAP_ARGS.copy()
        task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
        args_d["label_smoothing"] = 0.1 if task == "translation" else 0

-    tmp_dir = make_test_data_dir()
-    output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            output_dir=output_dir,
            do_predict=True,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )
        assert "n_train" in args_d
        args = argparse.Namespace(**args_d)
        module = main(args)

        input_embeds = module.model.get_input_embeddings()
        assert not input_embeds.weight.requires_grad
        if model == T5_TINY:
            lm_head = module.model.lm_head
            assert not lm_head.weight.requires_grad
            assert (lm_head.weight == input_embeds.weight).all().item()
        elif model == FSMT_TINY:
            fsmt = module.model.model
            embed_pos = fsmt.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not fsmt.decoder.embed_tokens.weight.requires_grad
            # check that embeds are not the same
            assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
        else:
            bart = module.model.model
            embed_pos = bart.decoder.embed_positions
            assert not embed_pos.weight.requires_grad
            assert not bart.shared.weight.requires_grad
            # check that embeds are the same
            assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
            assert bart.decoder.embed_tokens == bart.shared

        example_batch = load_json(module.output_dir / "text_batch.json")
        assert isinstance(example_batch, dict)
        assert len(example_batch) >= 4
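The two mechanical changes visible in test_finetune above are pytest.mark.parametrize → parameterized.expand (plain pytest parametrization does not work on unittest methods) and tempfile.mkdtemp() → self.get_auto_remove_tmp_dir(). A minimal sketch of that pattern, assuming only the TestCasePlus API already used in this diff; the class, test, and model names below are illustrative and not part of the commit:

    # Minimal sketch (not part of the diff): get_auto_remove_tmp_dir() hands back a temp dir
    # that the test framework deletes when the test finishes, so no manual cleanup is needed.
    from pathlib import Path

    from parameterized import parameterized

    from transformers.testing_utils import TestCasePlus


    class TmpDirPatternTest(TestCasePlus):
        @parameterized.expand(["first-model", "second-model"])
        def test_writes_into_auto_removed_dir(self, model_name):
            output_dir = self.get_auto_remove_tmp_dir()  # replaces tempfile.mkdtemp()
            marker = Path(output_dir) / f"{model_name}.txt"
            marker.write_text(model_name)
            assert marker.exists()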
-def test_finetune_extra_model_args():
+    def test_finetune_extra_model_args(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        args_d.update(
            data_dir=tmp_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # test models whose config includes the extra_model_args
        model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
        args_d1 = args_d.copy()
        args_d1.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            args_d1[p] = 0.5
        args = argparse.Namespace(**args_d1)
        model = main(args)
        for p in extra_model_params:
            assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

        # test models whose config doesn't include the extra_model_args
        model = T5_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_2_")
+        output_dir = self.get_auto_remove_tmp_dir()
        args_d2 = args_d.copy()
        args_d2.update(
            model_name_or_path=model,
            output_dir=output_dir,
        )
        unsupported_param = "encoder_layerdrop"
        args_d2[unsupported_param] = 0.5
        args = argparse.Namespace(**args_d2)
        with pytest.raises(Exception) as excinfo:
            model = main(args)
        assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
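The exception message asserted at the end of test_finetune_extra_model_args suggests that finetune.py only overrides config attributes the model config actually defines. A rough sketch of that kind of guard, offered as an assumption about the behavior rather than a copy of the real code; the function name is made up:

    def apply_extra_model_params(config, extra_params: dict):
        # Override only attributes the config already defines; otherwise fail loudly,
        # which is what the test's pytest.raises(Exception) block expects to catch.
        for name, value in extra_params.items():
            if not hasattr(config, name):
                raise AssertionError(f"model config doesn't have a `{name}` attribute")
            setattr(config, name, value)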
-def test_finetune_lr_schedulers():
+    def test_finetune_lr_schedulers(self):
        args_d: dict = CHEAP_ARGS.copy()

        task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())

        model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()

        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=model,
            output_dir=output_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task=task,
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )

        # emulate finetune.py
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = {"--help": True}

        # --help test
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStdout() as cs:
                args = parser.parse_args(args)
            assert False, "--help is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = lightning_base.arg_to_scheduler_metavar
        assert expected in cs.out, "--help is expected to list the supported schedulers"

        # --lr_scheduler=non_existing_scheduler test
        unsupported_param = "non_existing_scheduler"
        args = {f"--lr_scheduler={unsupported_param}"}
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStderr() as cs:
                args = parser.parse_args(args)
            assert False, "invalid argument is expected to sys.exit"
        assert excinfo.type == SystemExit
        expected = f"invalid choice: '{unsupported_param}'"
        assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

        # --lr_scheduler=existing_scheduler test
        supported_param = "cosine"
        args_d1 = args_d.copy()
        args_d1["lr_scheduler"] = supported_param
        args = argparse.Namespace(**args_d1)
        model = main(args)
-    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
+        assert (
+            getattr(model.hparams, "lr_scheduler") == supported_param
+        ), f"lr_scheduler={supported_param} shouldn't fail"
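test_finetune_lr_schedulers above checks argparse behavior by catching SystemExit and capturing the output streams. A stripped-down sketch of the same pattern against a bare parser, assuming only CaptureStderr from transformers.testing_utils; the parser and its scheduler choices are invented for illustration:

    # Minimal sketch (not part of the diff): an invalid value for a choices-restricted option
    # makes argparse print to stderr and raise SystemExit, which the test pattern captures.
    import argparse

    import pytest

    from transformers.testing_utils import CaptureStderr


    def check_invalid_choice_is_rejected():
        parser = argparse.ArgumentParser()
        parser.add_argument("--lr_scheduler", choices=["linear", "cosine"])
        with pytest.raises(SystemExit) as excinfo:
            with CaptureStderr() as cs:
                parser.parse_args(["--lr_scheduler=non_existing_scheduler"])
        assert excinfo.type == SystemExit
        assert "invalid choice: 'non_existing_scheduler'" in cs.err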