import os

import torch
from torch.utils.data import Dataset


def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize each line of the file at ``data_path`` into padded ``max_length`` tensors."""
    examples = []
    with open(data_path, "r") as f:
        for text in f.readlines():
            tokenized = tokenizer.batch_encode_plus(
                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
            )
            examples.append(tokenized)
    return examples
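

# A minimal usage sketch for encode_file (the checkpoint name and file path below are
# illustrative assumptions, not part of this module):
#
#     from transformers import BartTokenizer
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
#     articles = encode_file(tokenizer, "cnn_dm/val.source", max_length=1024)
#     articles[0]["input_ids"].shape  # torch.Size([1, 1024]) when pad_to_max_length=True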


def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id."""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
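

# A quick illustration of trim_batch on made-up tensors (a sketch, not part of the original file):
#
#     ids = torch.tensor([[5, 7, 1, 1], [5, 1, 1, 1]])   # pad_token_id == 1
#     trim_batch(ids, pad_token_id=1)
#     # -> tensor([[5, 7],
#     #            [5, 1]])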


class SummarizationDataset(Dataset):
    """Reads the line-aligned ``<type_path>.source`` / ``<type_path>.target`` files in ``data_dir``
    and tokenizes them eagerly for seq2seq summarization training."""

    def __init__(
        self,
        tokenizer,
        data_dir="./cnn-dailymail/cnn_dm/",
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
        self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        y = trim_batch(batch["target_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch):
        # Stack per-example tensors, then drop all-padding columns to shorten the batch.
        input_ids = torch.stack([x["source_ids"] for x in batch])
        masks = torch.stack([x["source_mask"] for x in batch])
        target_ids = torch.stack([x["target_ids"] for x in batch])
        pad_token_id = self.tokenizer.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}