transformers/tests/test_trainer.py
Commit dd9d483d03 by Julien Chaumond: Trainer (#3800)
* doc

* [tests] Add sample files for a regression task

* [HUGE] Trainer

* Feedback from @sshleifer

* Feedback from @thomwolf + logging tweak

* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes

* [glue] Use default max_seq_length of 128 like before

* [glue] move DataTrainingArguments around

* [ner] Change interface of InputExample, and align run_{tf,pl}

* Re-align the pl scripts a little bit

* ner

* [ner] Add integration test

* Fix language_modeling with API tweak

* [ci] Tweak loss target

* Don't break console output

* amp.initialize: model must be on right device before

* [multiple-choice] update for Trainer

* Re-align to 827d6d6ef0
2020-04-21 20:11:56 -04:00
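
The headline change in this commit is the new Trainer API, which the test file below exercises. As a rough orientation, a minimal fine-tune-and-evaluate flow with that API looks like the following sketch. The constructor arguments and trainer.evaluate() mirror the calls in the tests; trainer.train() and the train_dataset argument are assumed from the rest of the commit and are not exercised in this file, and the checkpoint name, data directory, and output directory are placeholders.

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    GlueDataTrainingArguments,
    Trainer,
    TrainingArguments,
)

MODEL_ID = "bert-base-cased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

# GlueDataTrainingArguments / GlueDataset as used in the tests below; data_dir is a placeholder.
data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
train_dataset = GlueDataset(data_args, tokenizer=tokenizer)
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)

training_args = TrainingArguments(output_dir="./mrpc_out", no_cuda=True)  # placeholder output dir
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # assumed argument, not used in the tests below
    eval_dataset=eval_dataset,
)
trainer.train()              # assumed; the tests below only call evaluate()
result = trainer.evaluate()  # returns a dict of metrics, e.g. result["loss"]

The test file itself follows.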


import unittest

from transformers import AutoTokenizer, TrainingArguments, is_torch_available

from .utils import require_torch


if is_torch_available():
    import torch

    from transformers import (
        Trainer,
        LineByLineTextDataset,
        AutoModelForSequenceClassification,
        DefaultDataCollator,
        DataCollatorForLanguageModeling,
        GlueDataset,
        GlueDataTrainingArguments,
        TextDataset,
    )


PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
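

# The collator tests below check that GLUE features batch with the expected label
# dtype (torch.long for MRPC classification, torch.float for STS-B regression) and
# that language-model collation handles tokenizers with and without a padding token.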
@require_torch
class DataCollatorIntegrationTest(unittest.TestCase):
    def test_default_classification(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./examples/tests_samples/MRPC", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
        data_collator = DefaultDataCollator()
        batch = data_collator.collate_batch(dataset.features)
        self.assertEqual(batch["labels"].dtype, torch.long)

    def test_default_regression(self):
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="sts-b", data_dir="./examples/tests_samples/STS-B", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)
        data_collator = DefaultDataCollator()
        batch = data_collator.collate_batch(dataset.features)
        self.assertEqual(batch["labels"].dtype, torch.float)
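
    # GPT-2 has no padding token, so collating line-by-line examples of uneven
    # length should raise; fixed-size blocks from TextDataset collate fine.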
    def test_lm_tokenizer_without_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        # ^ causal lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing on gpt2:
            data_collator.collate_batch(examples)

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator.collate_batch(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

    def test_lm_tokenizer_with_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
        data_collator = DataCollatorForLanguageModeling(tokenizer)
        # ^ masked lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator.collate_batch(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
        self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((31, 107)))

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator.collate_batch(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((2, 512)))
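

# End-to-end checks for the Trainer itself: evaluation of an already fine-tuned
# MRPC model on the sample data, and dataset construction for LM evaluation.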
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
    def test_trainer_eval_mrpc(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./examples/tests_samples/MRPC", overwrite_cache=True
        )
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True)

        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
        trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
        result = trainer.evaluate()
        self.assertLess(result["loss"], 0.2)

    def test_trainer_eval_lm(self):
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        dataset = LineByLineTextDataset(
            tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
        )
        self.assertEqual(len(dataset), 31)