[BigBird Pegasus] Make tests faster (#11744)

* improve tests

* remove bogus file

* make style

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen 2021-05-17 11:30:53 +01:00 committed by GitHub
parent a0531c8a24
commit 73893fc771
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -368,17 +368,24 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1)
def _check_batched_forward(self, attn_type, tolerance=1e-3):
config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type)
config, _ = self.model_tester.prepare_config_and_inputs()
config.max_position_embeddings = 128
config.block_size = 16
config.attention_type = attn_type
model = BigBirdPegasusForConditionalGeneration(config).to(torch_device)
model.eval()
sample_with_padding = [3, 8, 11] * 128 + [0] * 128
sample_without_padding = [4, 7, 9, 13] * 128
chunk_length = 32
sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length
sample_without_padding = [4, 7, 9, 13] * chunk_length
target_ids_without_padding = [2, 3] * 8
target_ids_with_padding = [7, 8] * 6 + 4 * [-100]
attention_mask = torch.tensor(
[[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long
[[1] * 3 * chunk_length + [0] * chunk_length, [1] * 4 * chunk_length],
device=torch_device,
dtype=torch.long,
)
input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long)
@@ -390,7 +397,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
with torch.no_grad():
logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits
logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits
self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))