[s2s] dynamic batch size with --max_tokens_per_batch (#7030)
This commit is contained in:
parent efeab6a3f1
commit a5638b2b3a
@ -352,3 +352,33 @@ runtime: 13H on V-100 16GB GPU.
```bash
pytest examples/seq2seq/
```

## Experimental Features

These features are harder to use and not always useful.

### Dynamic Batch Size for MT

`finetune.py` has a command-line argument `--max_tokens_per_batch` that allows batches to be dynamically sized.
This feature can only be used:
- with fairseq installed
- on 1 GPU
- without sortish sampler
- after calling `python save_len_file.py $tok $data_dir`

For example,
```bash
python save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
```
splits `wmt_en_ro/train` into 11,197 unevenly sized batches and can finish 1 epoch in 8 minutes on a V100.

For comparison,
```bash
./dynamic_bs_example.sh --sortish_sampler --train_batch_size 48
```
uses 12,723 batches of length 48 and takes slightly longer: 9.5 minutes.

The feature is still experimental because:
+ it could be made much more robust if we had memory-mapped/preprocessed datasets.
+ the speedup over the sortish sampler is not that large at the moment.
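For orientation, the pieces this diff adds can also be strung together directly in Python: `save_len_file` writes a `train.len`/`val.len` pickle of per-example token lengths, `Seq2SeqDataset` picks that file up, and `make_dynamic_sampler` packs indices into variable-sized batches under a token budget. The sketch below is illustrative rather than part of the commit; it assumes fairseq is installed, that it runs from `examples/seq2seq/` (so the plain `utils`/`save_len_file` imports resolve), and that a local `wmt_en_ro` data directory exists. The 128/2000 values mirror the example above.

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from save_len_file import save_len_file  # helper added by this PR
from utils import Seq2SeqDataset

MODEL = "Helsinki-NLP/opus-mt-en-ro"
DATA_DIR = "wmt_en_ro"  # assumed local data directory

# 1) Precompute token lengths; this writes wmt_en_ro/train.len and wmt_en_ro/val.len.
save_len_file(MODEL, DATA_DIR)

# 2) Build the dataset; because the .len file now exists, it is loaded
#    instead of falling back to character lengths.
tok = AutoTokenizer.from_pretrained(MODEL)
ds = Seq2SeqDataset(tok, DATA_DIR, max_source_length=128, max_target_length=128, type_path="train")

# 3) Pack sorted example indices into uneven batches capped at ~2000 tokens
#    (uses fairseq's batch_by_size under the hood).
batch_sampler = ds.make_dynamic_sampler(max_tokens_per_batch=2000)

# 4) batch_sampler replaces batch_size/shuffle/sampler in the DataLoader,
#    which is what the new branch in finetune.get_dataloader does.
loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
```

Each element yielded by the sampler is a list of dataset indices, so the batch size varies from step to step while the padded token count per batch stays roughly constant.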
17 examples/seq2seq/dynamic_bs_example.sh Executable file
@ -0,0 +1,17 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
export MAX_LEN=128
export m=sshleifer/student_marian_en_ro_6_1
python finetune.py \
    --learning_rate=3e-4 \
    --do_train \
    --fp16 \
    --data_dir wmt_en_ro \
    --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
    --freeze_encoder --freeze_embeds \
    --train_batch_size=48 --eval_batch_size=64 \
    --tokenizer_name $m --model_name_or_path $m --num_train_epochs=1 \
    --warmup_steps 500 --logger_name wandb --gpus 1 \
    --fp16_opt_level=O1 --task translation \
    "$@"
@ -68,6 +68,12 @@ class SummarizationModule(BaseTransformer):
    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
@ -97,6 +103,10 @@ class SummarizationModule(BaseTransformer):
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
            raise AssertionError("Sortish Sampler does not work for multigpu")
        if self.hparams.sortish_sampler and self.hparams.max_tokens_per_batch is not None:
            raise AssertionError("max tokens per batch and sortish sampler are incompatible.")
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
@ -175,6 +185,10 @@ class SummarizationModule(BaseTransformer):
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
@ -253,20 +267,39 @@ class SummarizationModule(BaseTransformer):

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader
        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
@ -313,6 +346,7 @@ class SummarizationModule(BaseTransformer):
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
46 examples/seq2seq/save_len_file.py Normal file
@ -0,0 +1,46 @@
import fire
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoTokenizer


try:
    from .utils import Seq2SeqDataset, pickle_save
except ImportError:
    from utils import Seq2SeqDataset, pickle_save


def save_len_file(
    tokenizer_name, data_dir, max_source_length=1024, max_target_length=1024, consider_target=False, **kwargs
):
    """Save max(src_len, tgt_len) for each example to allow dynamic batching."""
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    train_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="train", **kwargs)
    pad = tok.pad_token_id

    def get_lens(ds):
        dl = tqdm(
            DataLoader(ds, batch_size=512, num_workers=8, shuffle=False, collate_fn=ds.collate_fn),
            desc=str(ds.len_file),
        )
        max_lens = []
        for batch in dl:
            src_lens = batch["input_ids"].ne(pad).sum(1).tolist()
            tgt_lens = batch["labels"].ne(pad).sum(1).tolist()
            if consider_target:
                for src, tgt in zip(src_lens, tgt_lens):
                    max_lens.append(max(src, tgt))
            else:
                max_lens.extend(src_lens)
        return max_lens

    train_lens = get_lens(train_ds)
    val_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="val", **kwargs)
    val_lens = get_lens(val_ds)
    pickle_save(train_lens, train_ds.len_file)
    pickle_save(val_lens, val_ds.len_file)


if __name__ == "__main__":
    fire.Fire(save_len_file)
BIN examples/seq2seq/test_data/wmt_en_ro/train.len Normal file (binary file not shown)
@ -1,8 +1,11 @@
Corrections to votes and voting intentions: see Minutes Assignment conferred on a Member: see Minutes Membership of committees and delegations: see Minutes Decisions concerning certain documents: see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes
Membership of Parliament: see Minutes Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes Verification of credentials: see Minutes Documents received: see Minutes Written statements and oral questions (tabling): see Minutes Petitions: see Minutes Texts of agreements forwarded by the Council: see Minutes Action taken on Parliament's resolutions: see Minutes Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 7.45 p.m.)
Election of Vice-Presidents of the European Parliament (deadline for submitting nominations): see Minutes (The sitting was suspended at 12.40 p.m. and resumed at 3.00 p.m.) Election of Quaestors of the European Parliament (deadline for submitting nominations): see Minutes (The sitting was suspended at 3.25 p.m. and resumed at 6.00 p.m.) Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 6.15 p.m.) Opening of the sitting (The sitting was opened at 9.35 a.m.) Documents received: see Minutes Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes
Membership of committees (deadline for tabling amendments): see Minutes (The sitting was suspended at 7 p.m. and resumed at 9 p.m.) Agenda for next sitting: see Minutes Closure of sitting (The sitting was suspended at 23.25 p.m.) Documents received: see Minutes Communication of Council common positions: see Minutes (The sitting was suspended at 11.35 a.m. and resumed for voting time at noon) Approval of Minutes of previous sitting: see Minutes Committee of Inquiry into the crisis of the Equitable Life Assurance Society (extension of mandate): see Minutes
Announcement by the President: see Minutes 1. Membership of committees (vote) 2. Amendment of the ACP-EC Partnership Agreement (vote) 4. Certification of train drivers operating locomotives and trains on the railway system in the Community (vote) 6. Law applicable to non-contractual obligations ("ROME II") (vote) 8. Seventh and eighth annual reports on arms exports (vote) Corrections to votes and voting intentions: see Minutes Membership of committees and delegations: see Minutes Request for waiver of parliamentary immunity: see Minutes Decisions concerning certain documents: see Minutes
Written statements for entry
Written statements for entry in the register (Rule 116): see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes Adjournment of the session I declare the session of the European Parliament adjourned. (The sitting was closed at 1 p.m.) Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes Request for the defence of parliamentary immunity: see Minutes Appointments to committees (proposal by the Conference of Presidents): see Minutes Documents received: see Minutes Texts of agreements forwarded by the Council: see Minutes
Action taken on Parliament's resolutions: see Minutes Oral questions and written statements (tabling): see Minutes Written statements (Rule 116): see Minutes Agenda: see Minutes 1. Appointments to parliamentary committees (vote): see Minutes Voting time Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 12 midnight) Opening of the sitting (The sitting was opened at 09.05) Documents received: see Minutes Approval of Minutes of previous sitting: see Minutes 1. Protection of passengers against displaced luggage (vote) 2.
Approval of motor vehicles with regard to the forward field of vision of the driver (vote) 3. EC-Korea Agreement on scientific and technological cooperation (vote) 4. Mainstreaming sustainability in development cooperation policies (vote) 5. Draft Amending Budget No 1/2007 (vote) 7. EC-Gabon Fisheries Partnership (vote) 10. Limitation periods in cross-border disputes involving personal injuries and fatal accidents (vote) 12. Strategy for a strengthened partnership with the Pacific Islands (vote) 13. The European private company statute (vote) That concludes the vote.
Corrections to votes and voting intentions: see Minutes Assignment conferred on a Member: see Minutes Membership of committees and delegations: see Minutes Decisions concerning certain documents: see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes
Corrections to votes and voting intentions: see Minutes Assignment conferred on a Member: see Minutes Membership of committees and delegations: see Minutes Decisions concerning certain documents: see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes
Written statements for entry
@ -1,8 +1,11 @@
Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Misiune încredinţată unui deputat: consultaţi procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal
Componenţa Parlamentului: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal Verificarea prerogativelor: a se vedea procesul-verbal Depunere de documente: a se vedea procesul-verbal Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal Petiţii: a se vedea procesul-verbal Transmiterea de către Consiliu a textelor acordurilor: a se vedea procesul-verbal Cursul dat rezoluţiilor Parlamentului: a se vedea procesul-verbal Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Se levanta la sesión a las 19.45 horas)
Alegerea vicepreşedinţilor Parlamentului European (termenul de depunere a candidaturilor): consultaţi procesul-verbal (Die Sitzung wird um 12.40 Uhr unterbrochen und um 15.00 Uhr wiederaufgenommen). Alegerea chestorilor Parlamentului European (termenul de depunere a candidaturilor): consultaţi procesul-verbal (Die Sitzung wird um 15.25 Uhr unterbrochen und um 18.00 Uhr wiederaufgenommen). Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Die Sitzung wird um 18.15 Uhr geschlossen.) Deschiderea şedinţei (Die Sitzung wird um 9.35 Uhr eröffnet.) Depunerea documentelor: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal
Componenţa comisiilor (termenul de depunere a amendamentelor): consultaţi procesul-verbal (La seduta, sospesa alle 19.00, è ripresa alle 21.00) Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Die Sitzung wird um 23.25 Uhr geschlossen.) Depunerea documentelor: a se vedea procesul-verbal Comunicarea poziţiilor comune ale Parlamentului: a se vedea procesul-verbal (La séance, suspendue à 11h35 dans l'attente de l'Heure des votes, est reprise à midi) Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Comisia de anchetă privind criza societăţii de asigurări "Equitable Life” (prelungirea mandatului): consultaţi procesul-verbal
Comunicarea Preşedintelui: consultaţi procesul-verbal 1. Componenţa comisiilor (vot) 2. Modificarea Acordului de parteneriat ACP-CE ("Acordul de la Cotonou”) (vot) 4. Certificarea mecanicilor de locomotivă care conduc locomotive şi trenuri în sistemul feroviar comunitar (vot) 6. Legea aplicabilă obligaţiilor necontractuale ("Roma II”) (vot) 8. Al şaptelea şi al optulea raport anual privind exportul de armament (vot) Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Cerere de ridicare a imunităţii parlamentare: consultaţi procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal
Declaraţii scrise înscrise
Declaraţii scrise înscrise în registru (articolul 116 din Regulamentul de procedură): a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal Întreruperea sesiunii Dichiaro interrotta la sessione del Parlamento europeo. (La seduta è tolta alle 13.00) Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal Cerere de apărare a imunităţii parlamentare: consultaţi procesul-verbal Numiri în comisii (propunerea Conferinţei preşedinţilor): consultaţi procesul-verbal Depunerea documentelor: a se vedea procesul-verbal Transmiterea de către Consiliu a textelor acordurilor: a se vedea procesul-verbal
Continuări ale rezoluţiilor Parlamentului: consultaţi procesul-verbal Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal Declaraţii scrise (articolul 116 din Regulamentul de procedură) Ordinea de zi: a se vedea procesul-verbal 1. Numiri în comisiile parlamentare (vot): consultaţi procesul-verbal Timpul afectat votului Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (La seduta è tolta alle 24.00) Deschiderea şedinţei (The sitting was opened at 09.05) Depunerea documentelor: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal 1. Protecţia pasagerilor împotriva deplasării bagajelor (vot) 2.
Omologarea vehiculelor cu motor cu privire la câmpul de vizibilitate înainte al conducătorului auto (vot) 3. Acordul CE-Coreea de cooperare ştiinţifică şi tehnologică (vot) 4. Integrarea durabilităţii în politicile de cooperare pentru dezvoltare (vot) 5. Proiect de buget rectificativ nr.1/2007 (vot) 7. Acordul de parteneriat în domeniul pescuitului între Comunitatea Europeană şi Republica Gaboneză (vot) 10. Termenele de prescripţie aplicabile în cadrul litigiilor transfrontaliere cu privire la vătămările corporale şi accidentele mortale (vot) 12. Relaţiile UE cu insulele din Pacific: Strategie pentru un parteneriat consolidat (vot) 13. Statutul societăţii private europene (vot) Damit ist die Abstimmungsstunde beendet.
Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Misiune încredinţată unui deputat: consultaţi procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal
Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Misiune încredinţată unui deputat: consultaţi procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal
Declaraţii scrise înscrise
BIN examples/seq2seq/test_data/wmt_en_ro/val.len Normal file (binary file not shown)
188 examples/seq2seq/test_datasets.py Normal file
@ -0,0 +1,188 @@
import os
import tempfile
from pathlib import Path

import numpy as np
import pytest
from torch.utils.data import DataLoader

from transformers import AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import slow

from .pack_dataset import pack_data_dir
from .save_len_file import save_len_file
from .test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
from .utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset


BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"


@slow
@pytest.mark.parametrize(
    "tok_name",
    [
        MBART_TINY,
        MARIAN_TINY,
        T5_TINY,
        BART_TINY,
        PEGASUS_XSUM,
    ],
)
def test_seq2seq_dataset_truncation(tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]

        break  # No need to test every batch


@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
def test_legacy_dataset_truncation(tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = LegacySeq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_len_source
        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
        # show that targets were truncated
        assert batch["labels"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
        break  # No need to test every batch


def test_pack_dataset():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

    tmp_dir = Path(make_test_data_dir())
    orig_examples = tmp_dir.joinpath("train.source").open().readlines()
    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
    orig_paths = {x.name for x in tmp_dir.iterdir()}
    new_paths = {x.name for x in save_dir.iterdir()}
    packed_examples = save_dir.joinpath("train.source").open().readlines()
    # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
    # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
    assert len(packed_examples) < len(orig_examples)
    assert len(packed_examples) == 1
    assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
    assert orig_paths == new_paths


@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
def test_dynamic_batch_size():
    if not FAIRSEQ_AVAILABLE:
        return
    ds, max_tokens, tokenizer = _get_dataset(max_len=64)
    required_batch_size_multiple = 64
    batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
    batch_sizes = [len(x) for x in batch_sampler]
    assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
    assert sum(batch_sizes) == len(ds)  # no dropped or added examples
    data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
    failures = []
    num_src_per_batch = []
    for batch in data_loader:
        src_shape = batch["input_ids"].shape
        bs = src_shape[0]
        assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
        num_src_tokens = np.product(batch["input_ids"].shape)
        num_src_per_batch.append(num_src_tokens)
        if num_src_tokens > (max_tokens * 1.1):
            failures.append(num_src_tokens)
    assert num_src_per_batch[0] == max(num_src_per_batch)
    if failures:
        raise AssertionError(f"too many tokens in {len(failures)} batches")


def test_sortish_sampler_reduces_padding():
    ds, _, tokenizer = _get_dataset(max_len=512)
    bs = 2
    sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

    naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
    sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)

    pad = tokenizer.pad_token_id

    def count_pad_tokens(data_loader, k="input_ids"):
        return [batch[k].eq(pad).sum().item() for batch in data_loader]

    assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
    assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
    assert len(sortish_dl) == len(naive_dl)


def _get_dataset(n_obs=1000, max_len=128):
    if os.getenv("USE_REAL_DATA", False):
        data_dir = "examples/seq2seq/wmt_en_ro"
        max_tokens = max_len * 2 * 64
        if not Path(data_dir).joinpath("train.len").exists():
            save_len_file(MARIAN_TINY, data_dir)
    else:
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        max_tokens = max_len * 4
        save_len_file(MARIAN_TINY, data_dir)

    tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir=data_dir,
        type_path="train",
        max_source_length=max_len,
        max_target_length=max_len,
        n_obs=n_obs,
    )
    return ds, max_tokens, tokenizer


def test_distributed_sortish_sampler_splits_indices_between_procs():
    ds, max_tokens, tokenizer = _get_dataset()
    ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
    ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
    assert ids1.intersection(ids2) == set()
@ -10,21 +10,18 @@ from unittest.mock import patch
import pytest
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

import lightning_base
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow

from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .run_eval_search import run_search
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
from .utils import label_smoothed_nll_loss, lmap, load_json


logging.basicConfig(level=logging.DEBUG)
@ -32,6 +29,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
    "max_tokens_per_batch": None,
    "supervise_forward": True,
    "normalize_hidden": True,
    "label_smoothing": 0.2,
@ -106,8 +104,7 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"


stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@ -530,96 +527,3 @@ def test_finetune_lr_schedulers():
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"


def test_pack_dataset():
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

    tmp_dir = Path(make_test_data_dir())
    orig_examples = tmp_dir.joinpath("train.source").open().readlines()
    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
    pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
    orig_paths = {x.name for x in tmp_dir.iterdir()}
    new_paths = {x.name for x in save_dir.iterdir()}
    packed_examples = save_dir.joinpath("train.source").open().readlines()
    # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
    # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
    assert len(packed_examples) < len(orig_examples)
    assert len(packed_examples) == 1
    assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
    assert orig_paths == new_paths


@pytest.mark.parametrize(
    "tok_name",
    [
        MBART_TINY,
        MARIAN_TINY,
        T5_TINY,
        BART_TINY,
        PEGASUS_XSUM,
    ],
)
def test_seq2seq_dataset_truncation(tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    max_src_len = 4
    max_tgt_len = 8
    assert max_len_target > max_src_len  # Will be truncated
    assert max_len_source > max_src_len  # Will be truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=max_src_len,
        max_target_length=max_tgt_len,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert isinstance(batch, dict)
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_src_len
        # show that targets are the same len
        assert batch["labels"].shape[1] == max_tgt_len
        if tok_name != MBART_TINY:
            continue
        # check language codes in correct place
        batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]

        break  # No need to test every batch


@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
def test_legacy_dataset_truncation(tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = LegacySeq2SeqDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_len_source
        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
        # show that targets were truncated
        assert batch["labels"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
        break  # No need to test every batch
@ -21,6 +21,14 @@ from transformers import BartTokenizer
from transformers.file_utils import cached_property


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
@ -94,7 +102,13 @@ class AbstractSeq2SeqDataset(Dataset):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
@ -115,12 +129,42 @@ class AbstractSeq2SeqDataset(Dataset):
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python save_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")