transformers/examples/summarization/bart/utils.py

import os

import torch
from torch.utils.data import Dataset

from transformers.tokenization_utils import trim_batch


def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize one example per line of `data_path`, padded/truncated to `max_length`."""
    examples = []
    with open(data_path, "r") as f:
        for text in f.readlines():
            tokenized = tokenizer.batch_encode_plus(
                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
            )
            examples.append(tokenized)
    return examples
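
# Each dict returned by batch_encode_plus above holds "input_ids" and "attention_mask"
# tensors of shape (1, max_length) when pad_to_max_length=True; __getitem__ below
# squeezes that leading batch dimension away before collation.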


class SummarizationDataset(Dataset):
    """Loads parallel `<type_path>.source` / `<type_path>.target` files, one example per line."""

    def __init__(
        self,
        tokenizer,
        data_dir="./cnn-dailymail/cnn_dm/",
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
        self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        # Drop the leading batch dimension added by batch_encode_plus on a single line.
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        """Trim columns that are entirely padding from both source and target tensors."""
        y = trim_batch(batch["target_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch):
        # Stack per-example tensors, then trim trailing all-padding columns so the
        # batch is only as wide as its longest source/target sequence.
        input_ids = torch.stack([x["source_ids"] for x in batch])
        masks = torch.stack([x["source_mask"] for x in batch])
        target_ids = torch.stack([x["target_ids"] for x in batch])
        pad_token_id = self.tokenizer.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}