https://github.com/huggingface/transformers.git
[s2s] Delete useless method, log tokens_per_batch (#6081)
commit dafa296c95
parent dc4755c6d5
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -160,9 +160,16 @@ class SummarizationModule(BaseTransformer):
         )
         return (loss,)
 
+    @property
+    def pad(self) -> int:
+        return self.tokenizer.pad_token_id
+
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
+
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        # tokens per batch
+        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:
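The `tpb` value logged above counts the non-padding tokens the step actually consumes, on both the encoder and decoder side. A minimal standalone sketch of that computation (not part of the commit), using toy tensors and a pad id of 0 in place of self.tokenizer.pad_token_id:

import torch

pad = 0  # stands in for self.tokenizer.pad_token_id
batch = {
    "input_ids": torch.tensor([[5, 7, 9, 0, 0], [3, 4, 0, 0, 0]]),
    "decoder_input_ids": torch.tensor([[2, 6, 0], [2, 8, 1]]),
}

# .ne(pad) marks every non-padding position; .sum() counts them
tpb = batch["input_ids"].ne(pad).sum() + batch["decoder_input_ids"].ne(pad).sum()
print(tpb.item())  # 5 non-pad source tokens + 5 non-pad target tokens = 10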
@@ -172,7 +179,7 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
+        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
         rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
         rouges.update({k: v.item() for k, v in losses.items()})
         losses.update(rouges)
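The renamed "gen_len" key flows through the same mean-over-steps reduction as every other metric here. A toy sketch of that aggregation; the output dicts and metric names below are invented for illustration:

import numpy as np

metric_names = ["rouge2"]  # hypothetical; the module uses its own ROUGE metric names
outputs = [
    {"rouge2": 0.20, "gen_time": 0.5, "gen_len": 40.0},
    {"rouge2": 0.30, "gen_time": 0.7, "gen_len": 60.0},
]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in metric_names + ["gen_time", "gen_len"]}
print(rouges)  # {'rouge2': 0.25, 'gen_time': ~0.6, 'gen_len': 50.0}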
@@ -190,23 +197,21 @@ class SummarizationModule(BaseTransformer):
         return calculate_rouge(preds, target)
 
     def _generative_step(self, batch: dict) -> dict:
-        pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(
-            input_ids=source_ids,
-            attention_mask=source_mask,
+            batch["input_ids"],
+            attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / source_ids.shape[0]
-        preds = self.ids_to_clean_text(generated_ids)
-        target = self.ids_to_clean_text(y)
+        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        preds: List[str] = self.ids_to_clean_text(generated_ids)
+        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
         summ_len = np.mean(lmap(len, generated_ids))
-        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
+        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
         return base_metrics
 
     def test_step(self, batch, batch_idx):
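With the trimming step gone, generate() receives the padded batch directly and relies on attention_mask to ignore pad positions, so the output should be unchanged; skipping the trim only spends compute on pad columns. A hedged sketch of that call path outside the module; the checkpoint name ("t5-small") is an assumption for illustration, not something the commit specifies:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = tokenizer(
    ["summarize: a short document", "summarize: a considerably longer second document"],
    padding=True,  # produces the trailing pad columns trim_seq2seq_batch used to crop
    return_tensors="pt",
)
generated_ids = model.generate(
    enc["input_ids"],
    attention_mask=enc["attention_mask"],
    use_cache=True,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))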
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -128,12 +128,6 @@ class Seq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
 
-    @staticmethod
-    def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
-        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
-        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
-        return source_ids, source_mask, y
-
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
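For reference, the deleted helper only cropped columns that are padding in every row before generation. A standalone sketch of that behavior, assuming the column-cropping semantics of trim_batch; the function name trim_pad_columns is hypothetical, not the library's:

import torch

def trim_pad_columns(ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # keep only columns holding at least one real (non-pad) token
    keep = ids.ne(pad_token_id).any(dim=0)
    return ids[:, keep]

ids = torch.tensor([[5, 7, 0, 0], [3, 0, 0, 0]])
print(trim_pad_columns(ids, pad_token_id=0))  # tensor([[5, 7], [3, 0]])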