From dafa296c952c08fca3686f1cf8f3a8f8eb116744 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 28 Jul 2020 11:24:23 -0400
Subject: [PATCH] [s2s] Delete useless method, log tokens_per_batch (#6081)

---
 examples/seq2seq/finetune.py | 23 ++++++++++++++---------
 examples/seq2seq/utils.py    |  6 ------
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index e2e9ecffa26..c7138295460 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -160,9 +160,16 @@ class SummarizationModule(BaseTransformer):
         )
         return (loss,)

+    @property
+    def pad(self) -> int:
+        return self.tokenizer.pad_token_id
+
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
+
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        # tokens per batch
+        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}

     def validation_step(self, batch, batch_idx) -> Dict:
@@ -172,7 +179,7 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
+        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
         rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
         rouges.update({k: v.item() for k, v in losses.items()})
         losses.update(rouges)
@@ -190,23 +197,21 @@ class SummarizationModule(BaseTransformer):
         return calculate_rouge(preds, target)

     def _generative_step(self, batch: dict) -> dict:
-        pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(
-            input_ids=source_ids,
-            attention_mask=source_mask,
+            batch["input_ids"],
+            attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / source_ids.shape[0]
-        preds = self.ids_to_clean_text(generated_ids)
-        target = self.ids_to_clean_text(y)
+        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        preds: List[str] = self.ids_to_clean_text(generated_ids)
+        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
         summ_len = np.mean(lmap(len, generated_ids))
-        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
+        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
         return base_metrics

     def test_step(self, batch, batch_idx):
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 49910ab6216..7d9288333c9 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -128,12 +128,6 @@ class Seq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]

-    @staticmethod
-    def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
-        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
-        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
-        return source_ids, source_mask, y
-
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
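
Not part of the patch: a minimal, self-contained sketch of what the new "tpb" value logged in training_step computes, i.e. the number of non-padding tokens across the encoder and decoder inputs of one batch. The pad id and tensor contents below are made-up example values, not taken from the commit.

# Illustrative sketch only: mirrors logs["tpb"] from training_step above.
import torch

pad_token_id = 0  # assumed pad id for the demo; the module reads self.tokenizer.pad_token_id

# Made-up batch: 0 marks padding positions.
batch = {
    "input_ids": torch.tensor([[5, 8, 2, 0, 0], [7, 3, 9, 4, 0]]),
    "decoder_input_ids": torch.tensor([[6, 1, 0], [2, 4, 9]]),
}

# Count every token that is not padding in both the source and target tensors.
tokens_per_batch = (
    batch["input_ids"].ne(pad_token_id).sum()
    + batch["decoder_input_ids"].ne(pad_token_id).sum()
)
print(int(tokens_per_batch))  # 7 source tokens + 5 target tokens -> 12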