[examples/seq2seq]: add --label_smoothing option (#5919)
This commit is contained in: parent 95d1962b9c · commit 5b193b39b0
@@ -27,8 +27,18 @@ this should make a directory called `cnn_dm/` with files like `test.source`.
 ```
 
 WMT16 English-Romanian Translation Data:
+
+This dataset comes in two formats. The "packed" version merges short training examples into examples of <200 tokens to increase GPU utilization (and also improves validation performance).
+
 ```bash
 cd examples/seq2seq
+wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro_packed_train_200.tgz
+tar -xzvf wmt_en_ro_packed_200.tgz
+export ENRO_DIR=wmt_en_ro_packed_train_200
+```
+
+The original data can also be downloaded with this command:
+```bash
 wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
 tar -xzvf wmt_en_ro.tar.gz
 export ENRO_DIR=${PWD}/wmt_en_ro
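The "packed" format described in the hunk above merges short training examples until a token budget is reached. As a rough illustration of the idea only (the `pack_examples` helper below is hypothetical and `tokenizer` is assumed to be any HuggingFace tokenizer; this is not the code the example ships):

```python
# Sketch only: greedily merge short lines into chunks of at most `max_tokens` tokens.
def pack_examples(lines, tokenizer, max_tokens=200, sep=" "):
    packed, current = [], ""
    for line in lines:
        candidate = f"{current}{sep}{line}".strip() if current else line
        if len(tokenizer.tokenize(candidate)) <= max_tokens:
            current = candidate  # still under the budget: keep merging
        else:
            if current:
                packed.append(current)  # flush the finished chunk
            current = line  # start a new chunk with the line that did not fit
    if current:
        packed.append(current)
    return packed
```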
@@ -84,16 +94,31 @@ The following command should work on a 16GB GPU:
 
 First, follow the wmt_en_ro download instructions.
 Then you can finetune mbart_cc25 on English-Romanian with the following command.
 **Recommendation:** Read and potentially modify the fairly opinionated defaults in the `train_mbart_cc25_enro.sh` script before running it.
 
+Best performing command:
 ```bash
-export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
-export MAX_LEN=128
+# optionally
+export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
+# export WANDB_PROJECT="MT" # optional
+export MAX_LEN=200
 export BS=4
-export GAS=8
-./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
+export GAS=8 # gradient accumulation steps
+./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
 ```
+This should take < 2h/epoch on a 16GB v100 and achieve a val_avg_bleu score above 25 (you can see it in wandb or in metrics.json).
+To get results in line with fairseq, you need to do some postprocessing.
 
+MultiGPU command
+(using 8 GPUs as an example)
+```bash
+export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
+# export WANDB_PROJECT="MT" # optional
+export MAX_LEN=200
+export BS=4
+export GAS=1 # gradient accumulation steps
+./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
+```
 ### Finetuning Outputs
 As you train, `output_dir` will be filled with files that look kind of like this (comments are mine).
 Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
@@ -108,7 +133,7 @@ output_dir
 │   ├── tokenizer_config.json
 │   └── vocab.json
 ├── git_log.json # repo, branch, and commit hash
-├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
+├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
 ├── metrics.json # new validation metrics will continually be appended to this
 ├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
 │   ├── config.json
@@ -5,7 +5,7 @@ from pathlib import Path
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.utilities import rank_zero_only
 
 
@@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
+
+
+def get_early_stopping_callback(metric, patience):
+    return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)
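The new `get_early_stopping_callback` wraps PyTorch Lightning's `EarlyStopping`, monitoring `val_<metric>` in "max" mode. A minimal usage sketch, with an assumed metric name and Trainer wiring (in the example itself the callback is handed to `generic_train` rather than built by hand):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# Stop after 4 validation checks without improvement of "val_bleu".
# Patience counts validation runs, not epochs, so val_check_interval matters.
es_callback = EarlyStopping(monitor="val_bleu", mode="max", patience=4, verbose=True)

# Depending on the Lightning version, the callback is passed either via a dedicated
# early-stopping argument or through the generic callbacks list; the list form is shown here.
trainer = pl.Trainer(max_epochs=6, callbacks=[es_callback])
```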
@@ -33,9 +33,10 @@ try:
         calculate_bleu_score,
         Seq2SeqDataset,
         MBartDataset,
+        label_smoothed_nll_loss,
     )
 
-    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
+    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
 except ImportError:
     from utils import (
         Seq2SeqDataset,
@@ -52,8 +53,9 @@ except ImportError:
         get_git_info,
         ROUGE_KEYS,
         calculate_bleu_score,
+        label_smoothed_nll_loss,
     )
-    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
+    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
 
 logger = logging.getLogger(__name__)
 
@@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer):
 
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
-        y_ids = y[:, :-1].contiguous()
-        lm_labels = y[:, 1:].clone()
-        lm_labels[y[:, 1:] == pad_token_id] = -100
-        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
-        loss = outputs[0]
+        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
+        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
+        lm_labels = target_ids[:, 1:].clone()  # why clone?
+        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+
+        if self.hparams.label_smoothing == 0:
+            # Same behavior as modeling_bart.py
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            lm_logits = outputs[0]
+            assert lm_logits.shape[-1] == self.model.config.vocab_size
+            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
+        else:
+            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
+            loss, nll_loss = label_smoothed_nll_loss(
+                lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
+            )
         return (loss,)
 
     def training_step(self, batch, batch_idx) -> Dict:
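To see what the two branches of the new `_step` compute, here is a small self-contained comparison on toy tensors (shapes, vocabulary size, and values are made up; normalization is left out here, whereas the `label_smoothed_nll_loss` helper added in the utils hunk below divides by a batch-size term):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, pad_token_id, epsilon = 10, 1, 0.1
lm_logits = torch.randn(2, 3, vocab_size)            # (batch, seq_len, vocab)
lm_labels = torch.tensor([[4, 7, pad_token_id], [2, 5, 9]])

# Branch 1 (label_smoothing == 0): plain cross-entropy, pad positions ignored.
ce = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
plain_loss = ce(lm_logits.view(-1, vocab_size), lm_labels.view(-1))

# Branch 2: fairseq-style smoothing,
#   loss = (1 - eps) * nll + (eps / vocab_size) * sum_over_vocab(-log p)
lprobs = F.log_softmax(lm_logits, dim=-1)
target = lm_labels.unsqueeze(-1)
nll = -lprobs.gather(dim=-1, index=target)
smooth = -lprobs.sum(dim=-1, keepdim=True)
pad_mask = target.eq(pad_token_id)
nll.masked_fill_(pad_mask, 0.0)
smooth.masked_fill_(pad_mask, 0.0)
smoothed_loss = (1.0 - epsilon) * nll.sum() + (epsilon / vocab_size) * smooth.sum()

print(plain_loss.item(), smoothed_loss.item())        # the two objectives differ
```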
@@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument(
             "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
         )
+        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
+        parser.add_argument(
+            "--early_stopping_patience",
+            type=int,
+            default=-1,
+            required=False,
+            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
+        )
         return parser
 
 
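A quick standalone check of how the two new flags parse. This mirrors only the added arguments, not the whole `add_model_specific_args` parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--early_stopping_patience", type=int, default=-1, required=False)

args = parser.parse_args(["--label_smoothing", "0.1", "--early_stopping_patience", "4"])
assert args.label_smoothing == 0.1 and args.early_stopping_patience == 4

# Defaults: no smoothing, and -1 means early stopping is disabled.
defaults = parser.parse_args([])
assert defaults.label_smoothing == 0.0 and defaults.early_stopping_patience == -1
```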
@@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule:
     elif args.logger_name == "wandb":
         from pytorch_lightning.loggers import WandbLogger
 
-        logger = WandbLogger(name=model.output_dir.name, project=dataset)
+        project = os.environ.get("WANDB_PROJECT", dataset)
+        logger = WandbLogger(name=model.output_dir.name, project=project)
 
     elif args.logger_name == "wandb_shared":
         from pytorch_lightning.loggers import WandbLogger
 
         logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
 
+    if args.early_stopping_patience >= 0:
+        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
+    else:
+        es_callback = False
     trainer: pl.Trainer = generic_train(
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
+        early_stopping_callback=es_callback,
         logger=logger,
         # TODO: early stopping callback seems messed up
     )
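Besides wiring the early-stopping callback into `generic_train`, the hunk above makes the wandb project name come from the `WANDB_PROJECT` environment variable when it is set, falling back to the dataset name otherwise. The lookup amounts to the following (values are illustrative):

```python
import os

dataset = "wmt_en_ro"  # illustrative fallback, normally derived from the data directory
# Running `export WANDB_PROJECT="MT"` before launching overrides the fallback.
project = os.environ.get("WANDB_PROJECT", dataset)
print(f"wandb project: {project}")
```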
@@ -12,14 +12,14 @@ import torch
 from pytest import param
 from torch.utils.data import DataLoader
 
-from transformers import AutoTokenizer, MBartTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
 from transformers.testing_utils import require_multigpu
 
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import main
 from .pack_dataset import pack_data_dir
 from .run_eval import generate_summaries_or_translations, run_generate
-from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
+from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -27,7 +27,8 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
-    "label_smoothing_eps": 0.2,
+    "label_smoothing": 0.2,
+    "early_stopping_patience": 2,
     "logger_name": "default",
     "length_penalty": 0.5,
     "cache_dir": "",
@@ -142,6 +143,26 @@ class TestSummarizationDistiller(unittest.TestCase):
 
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
 
+    def test_loss_fn(self):
+        model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
+        input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
+        target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
+        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
+        lm_labels = target_ids[:, 1:].clone()  # why clone?
+        model_computed_loss = model(
+            input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
+        ).loss
+
+        logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
+
+        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+        smoothed_loss, nll_loss = label_smoothed_nll_loss(
+            lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
+        )
+        with self.assertRaises(AssertionError):
+            # TODO: understand why this breaks
+            self.assertEqual(nll_loss, model_computed_loss)
+
     @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
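The `assertRaises(AssertionError)` above encodes that the helper's `nll_loss` does not currently match the model's built-in loss. One plausible reason (my reading of the code in this diff, not a statement from the authors) is normalization: `CrossEntropyLoss` averages over non-padded tokens, while the helper sums and then divides by `bs`, which in the padded branch is the count of pad positions. A toy sketch of that scale mismatch, with made-up shapes and values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad, vocab = 1, 8
logits = torch.randn(2, 3, vocab)
labels = torch.tensor([[3, 5, pad], [2, 6, 7]])

# Mean over the 5 non-pad tokens (CrossEntropyLoss default reduction).
mean_loss = torch.nn.CrossEntropyLoss(ignore_index=pad)(logits.view(-1, vocab), labels.view(-1))

# Sum over non-pad tokens, divided by the number of pad positions (1 here),
# mimicking `bs = pad_mask.long().sum()` in the helper.
lprobs = F.log_softmax(logits, dim=-1)
nll = -lprobs.gather(-1, labels.unsqueeze(-1)).masked_fill(labels.unsqueeze(-1).eq(pad), 0.0)
helper_style_loss = nll.sum() / 1

print(mean_loss.item(), helper_style_loss.item())  # different scales, so assertEqual would fail
```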
@@ -156,6 +177,8 @@ class TestSummarizationDistiller(unittest.TestCase):
 
     def _test_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
+            label_smoothing_eps=0.0,
+            early_stopping_patience=-1,
             train_batch_size=1,
             eval_batch_size=2,
             max_epochs=2,
@@ -215,6 +238,8 @@ def test_run_eval_bart(model):
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
     task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
+    args_d["label_smoothing"] = 0.1 if task == "translation" else 0
+
     tmp_dir = make_test_data_dir()
     output_dir = tempfile.mkdtemp(prefix="output_")
     args_d.update(
@@ -4,17 +4,17 @@ export PYTHONPATH="../":"${PYTHONPATH}"
 python finetune.py \
     --learning_rate=3e-5 \
     --fp16 \
-    --gpus 1 \
     --do_train \
     --do_predict \
-    --val_check_interval 0.1 \
+    --val_check_interval 0.25 \
     --adam_eps 1e-06 \
-    --num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
-    --freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
+    --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
+    --data_dir $ENRO_DIR \
     --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
-    --model_name_or_path facebook/mbart-large-cc25 \
     --task translation \
     --warmup_steps 500 \
-    --logger_name wandb --sortish_sampler \
+    --freeze_embeds \
+    --early_stopping_patience 4 \
+    --model_name_or_path facebook/mbart-large-cc25 \
     $@
@@ -1,18 +0,0 @@
-#!/usr/bin/env bash
-export PYTHONPATH="../":"${PYTHONPATH}"
-# Need to export N_GPUS=
-python finetune.py \
-    --learning_rate=3e-5 \
-    --fp16 \
-    --gpus $N_GPUS \
-    --do_train \
-    --val_check_interval 0.25 \
-    --adam_eps 1e-06 \
-    --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
-    --data_dir $ENRO_DIR \
-    --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
-    --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
-    --tokenizer facebook/mbart-large-cc25 \
-    --task translation \
-    --warmup_steps 500 --freeze_encoder --freeze_embeds \
-    $@
@@ -19,6 +19,29 @@ from torch.utils.data import Dataset, Sampler
 from transformers import BartTokenizer
 
 
+def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
+    """From fairseq"""
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(-1)
+    nll_loss = -lprobs.gather(dim=-1, index=target)
+    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        nll_loss.masked_fill_(pad_mask, 0.0)
+        smooth_loss.masked_fill_(pad_mask, 0.0)
+        bs = pad_mask.long().sum()
+    else:
+        nll_loss = nll_loss.squeeze(-1)
+        smooth_loss = smooth_loss.squeeze(-1)
+        bs = lprobs.shape[0]
+
+    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
+    smooth_loss = smooth_loss.sum()
+    eps_i = epsilon / lprobs.size(-1)
+    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+    return loss / bs, nll_loss / bs
+
+
 def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
     extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
     return tokenizer(
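The helper follows the fairseq recipe: a convex combination of the summed, pad-masked NLL and a uniform-over-vocabulary term, both divided by `bs`. A small usage sketch with toy shapes; the import assumes the example's `utils` module is on the path:

```python
import torch
import torch.nn.functional as F

from utils import label_smoothed_nll_loss  # assumes examples/seq2seq is importable

torch.manual_seed(0)
pad = 1
logits = torch.randn(2, 4, 12)                       # (batch, seq_len, vocab)
labels = torch.tensor([[3, 5, 7, pad], [2, 9, 4, 6]])

lprobs = F.log_softmax(logits, dim=-1)
loss, nll = label_smoothed_nll_loss(lprobs, labels, 0.1, ignore_index=pad)

# With epsilon=0 the smoothed loss collapses to the (masked, summed, normalized) NLL term.
loss0, nll0 = label_smoothed_nll_loss(lprobs, labels, 0.0, ignore_index=pad)
assert torch.isclose(loss0, nll0)
print(loss.item(), nll.item())
```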
@@ -144,8 +167,8 @@ class MBartDataset(Seq2SeqDataset):
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
         return {
-            "tgt_texts": source_line,
-            "src_texts": tgt_line,
+            "tgt_texts": tgt_line,
+            "src_texts": source_line,
         }
 
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]: