[docs] Add integration test example to copy pasta template (#5961)

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Sam Shleifer 2020-07-22 12:48:38 -04:00 committed by GitHub
parent 01116d3c5b
commit feeb956a19
8 changed files with 287 additions and 222 deletions

View File

@ -215,7 +215,7 @@ Follow these steps to start contributing:
`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
- If you are adding a new tokenizer, write tests, and make sure
`RUN_SLOW=1 python -m pytest tests/test_tokenization_{your_model_name}.py` passes.
CircleCI does not run the slow tests.
CircleCI does not run the slow tests, but GitHub Actions runs them every night!
6. All public methods must have informative docstrings that work nicely with sphinx. See `modeling_ctrl.py` for an
example.
@ -239,6 +239,16 @@ $ pip install -r examples/requirements.txt # only needed the first time
$ python -m pytest -n auto --dist=loadfile -s -v ./examples/
```
and for the slow tests:
```bash
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
```
or
```bash
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./examples/
```
In fact, that's how `make test` and `make test-examples` are implemented!
You can specify a smaller set of tests in order to test only the feature

View File

@ -5,6 +5,7 @@ import unittest
from unittest.mock import patch
import run_glue_deebert
from transformers.testing_utils import slow
logging.basicConfig(level=logging.DEBUG)
@ -20,6 +21,7 @@ def get_setup_file():
class DeeBertTests(unittest.TestCase):
@slow
def test_glue_deebert(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

View File

@ -128,3 +128,11 @@ if _torch_available:
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
else:
torch_device = None
def require_torch_and_cuda(test_case):
"""Decorator marking a test that requires CUDA and PyTorch). """
if torch_device != "cuda":
return unittest.skip("test requires CUDA")(test_case)
else:
return test_case
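
For context, a minimal sketch of how a test would opt into the new decorator; the test class, test name, and assertion are made up, and it assumes the helper ends up importable from `transformers.testing_utils` alongside `slow` and `torch_device`:
```python
import unittest

from transformers.testing_utils import require_torch_and_cuda, torch_device


class GpuOnlyExampleTest(unittest.TestCase):
    @require_torch_and_cuda
    def test_only_runs_on_gpu(self):
        # The decorator skips this whole test whenever torch_device != "cuda".
        self.assertEqual(torch_device, "cuda")
```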

View File

@ -1,5 +1,15 @@
# How to add a new example script in 🤗Transformers
This folder provides a template for adding a new example script that implements a training or inference task with the models in the 🤗Transformers library.
Add tests!
Currently, only examples for PyTorch are provided. They are adaptations of the library's SQuAD examples, which implement single-GPU and distributed training with gradient accumulation and mixed precision (using NVIDIA's apex library) to cover a reasonable range of use cases.
These files can be put in a subdirectory under your example's name, like `examples/deebert`.
Best Practices:
- use `Trainer`/`TFTrainer`
- write an @slow test that checks that your model can train on one batch and get a low loss (see the sketch after this list)
- this test should use CUDA if it is available (e.g. by checking `transformers.torch_device`)
- add an `eval_xxx.py` script that can evaluate a pretrained checkpoint
- tweet about your new example with a carbon screenshot of how to run it and tag @huggingface
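
As a rough sketch of the @slow one-batch training test suggested above: the checkpoint, texts, learning rate, and step count below are placeholders to adapt to your example, and for brevity it uses a plain optimizer loop rather than `Trainer`. It picks up `torch_device` so it uses CUDA when available.
```python
import unittest

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import slow, torch_device


class ExampleTrainingTest(unittest.TestCase):
    @slow
    def test_trains_on_one_batch(self):
        # Placeholder checkpoint; swap in whatever your example actually fine-tunes.
        model_name = "distilbert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name).to(torch_device)

        batch = tokenizer(
            ["a delightful movie", "a dreadful movie"], padding=True, return_tensors="pt"
        ).to(torch_device)
        labels = torch.tensor([1, 0], device=torch_device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

        model.train()
        first_loss = None
        for _ in range(10):  # a handful of steps on the same batch should drive the loss down
            optimizer.zero_grad()
            loss = model(**batch, labels=labels)[0]
            if first_loss is None:
                first_loss = loss.item()
            loss.backward()
            optimizer.step()

        self.assertLess(loss.item(), first_loss)
```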

View File

@ -18,6 +18,7 @@ Here an overview of the general workflow:
- [ ] add model/configuration/tokenization classes
- [ ] add conversion scripts
- [ ] add tests
- [ ] add @slow integration test
- [ ] finalize
Let's detail what should be done at each step

View File

@ -347,23 +347,7 @@ class XxxModel(XxxPreTrainedModel):
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We create a 3D attention mask from a 2D tensor mask.
# (this can be done with self.invert_attention_mask)
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N

View File

@ -20,7 +20,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device
from .utils import require_torch, require_torch_and_cuda, slow, torch_device
if is_torch_available():
@ -31,8 +31,207 @@ if is_torch_available():
XxxForQuestionAnswering,
XxxForSequenceClassification,
XxxForTokenClassification,
AutoModelForMaskedLM,
AutoTokenizer,
)
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.file_utils import cached_property
#
class XxxModelTester:
"""You can also import this e.g from .test_modeling_bart import BartModelTester """
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_xxx_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_xxx_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_xxx_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = XxxForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_and_check_xxx_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = XxxForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
@ -44,204 +243,8 @@ class XxxModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
class XxxModelTester(object):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_xxx_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
def create_and_check_xxx_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_xxx_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = XxxForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_and_check_xxx_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = XxxForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
)
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = XxxModelTest.XxxModelTester(self)
self.model_tester = XxxModelTester(self)
self.config_tester = ConfigTester(self, config_class=XxxConfig, hidden_size=37)
def test_config(self):
@ -268,7 +271,50 @@ class XxxModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_xxx_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in XXX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
def test_lm_outputs_same_as_reference_model(self):
"""Write something that could help someone fixing this here."""
checkpoint_path = "XXX/bart-large"
model = self.big_model
tokenizer = AutoTokenizer.from_pretrained(
checkpoint_path
) # same with AutoTokenizer (see tokenization_auto.py). This is not mandatory
# MODIFY THIS DEPENDING ON YOUR MODELS RELEVANT TASK.
batch = tokenizer(["I went to the <mask> yesterday"], return_tensors="pt").to(torch_device)
desired_mask_result = "store"  # update this to the string you expect at the masked position
logits = model(**batch).logits
masked_index = (batch.input_ids[0] == tokenizer.mask_token_id).nonzero()  # position of the <mask> token
assert model.num_parameters() == 175e9 # a joke
mask_entry_logits = logits[0, masked_index.item(), :]
probs = mask_entry_logits.softmax(dim=0)
_, predictions = probs.topk(1)
self.assertEqual(tokenizer.decode(predictions), desired_mask_result)
@cached_property
def big_model(self):
"""Cached property means this code will only be executed once."""
checkpoint_path = "XXX/bart-large"
model = AutoModelForMaskedLM.from_pretrained(checkpoint_path).to(
torch_device
) # test whether AutoModel can determine your model_class from checkpoint name
if torch_device == "cuda":
model.half()
# optional: do more testing here! This will save you time later!
return model
@slow
def test_that_XXX_can_be_used_in_a_pipeline(self):
"""We can use self.big_model here without calling __init__ again."""
pass
def test_XXX_loss_doesnt_change_if_you_add_padding(self):
pass
def test_XXX_bad_args(self):
pass
def test_XXX_backward_pass_reduces_loss(self):
"""Test loss/gradients same as reference implementation, for example."""
pass
@require_torch_and_cuda
def test_large_inputs_in_fp16_dont_cause_overflow(self):
pass
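
Not part of the template, but as a rough idea of how the fp16 stub above might eventually be filled in; it reuses the placeholder checkpoint from `big_model` and assumes `import torch` is available in this module:
```python
@require_torch_and_cuda
def test_large_inputs_in_fp16_dont_cause_overflow(self):
    tokenizer = AutoTokenizer.from_pretrained("XXX/bart-large")  # placeholder checkpoint, as above
    batch = tokenizer(["a very long input " * 128], truncation=True, return_tensors="pt").to(torch_device)
    logits = self.big_model(**batch)[0]  # big_model is already .half() on cuda
    # In fp16, long inputs can overflow to inf/nan; check that every logit stays finite.
    self.assertTrue(torch.isfinite(logits.float()).all())
```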

View File

@ -62,3 +62,7 @@ class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_special_tokens_as_you_expect(self):
"""If you are training a seq2seq model that expects a decoder_prefix token make sure it is prepended to decoder_input_ids """
pass
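
As a loose illustration of what that stub could assert, here is a hypothetical body; the `cls`/`sep` token ids stand in for whatever prefix (or decoder prefix) and suffix tokens your tokenizer is actually supposed to add:
```python
def test_special_tokens_as_you_expect(self):
    """If you are training a seq2seq model that expects a decoder_prefix token, make sure it is prepended to decoder_input_ids."""
    tokenizer = self.get_tokenizer()
    ids = tokenizer.encode("UNwant\u00E9d,running", add_special_tokens=False)
    # build_inputs_with_special_tokens is what encode()/__call__ use under the hood.
    with_special = tokenizer.build_inputs_with_special_tokens(ids)
    self.assertEqual(with_special[0], tokenizer.cls_token_id)   # replace with your decoder prefix token id
    self.assertEqual(with_special[-1], tokenizer.sep_token_id)  # replace with your expected suffix token id
```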