mirror of https://github.com/huggingface/transformers.git
[BigBird Pegasus] Make tests faster (#11744)
* improve tests
* remove bogus file
* make style

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
parent a0531c8a24
commit 73893fc771
@@ -368,17 +368,24 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1)

     def _check_batched_forward(self, attn_type, tolerance=1e-3):
-        config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type)
+        config, _ = self.model_tester.prepare_config_and_inputs()
+        config.max_position_embeddings = 128
+        config.block_size = 16
+        config.attention_type = attn_type
         model = BigBirdPegasusForConditionalGeneration(config).to(torch_device)
         model.eval()

-        sample_with_padding = [3, 8, 11] * 128 + [0] * 128
-        sample_without_padding = [4, 7, 9, 13] * 128
+        chunk_length = 32
+
+        sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length
+        sample_without_padding = [4, 7, 9, 13] * chunk_length
         target_ids_without_padding = [2, 3] * 8
         target_ids_with_padding = [7, 8] * 6 + 4 * [-100]

         attention_mask = torch.tensor(
-            [[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long
+            [[1] * 3 * chunk_length + [0] * chunk_length, [1] * 4 * chunk_length],
+            device=torch_device,
+            dtype=torch.long,
         )

         input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long)
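The speed-up in this hunk comes from shrinking the test inputs rather than changing the model: with chunk_length = 32 each sample is 4 * 32 = 128 tokens instead of 4 * 128 = 512, which is also why config.max_position_embeddings is set to 128, while block_size = 16 still leaves 8 blocks per sequence so the block-sparse attention path is exercised. A minimal sketch of that arithmetic follows; the variable names mirror the test above, and nothing here is part of the commit itself.

# Sketch only: the sequence-length arithmetic implied by the diff above.
block_size = 16                  # config.block_size in the test
chunk_length = 32                # new value; the old test effectively used 128

old_seq_len = 4 * 128            # e.g. [3, 8, 11] * 128 + [0] * 128 -> 512 tokens
new_seq_len = 4 * chunk_length   # [3, 8, 11] * 32 + [0] * 32 -> 128 tokens

print(old_seq_len, old_seq_len // block_size)  # 512 tokens, 32 blocks per sequence
print(new_seq_len, new_seq_len // block_size)  # 128 tokens, 8 blocks per sequence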
@@ -390,7 +397,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
             logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits

         with torch.no_grad():
-            logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits
+            logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits

         self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))

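The assertion kept in this hunk checks batched-vs-single consistency: logits for a padded batch entry should match the logits obtained by running that entry alone with its padding stripped. Below is a self-contained sketch of that general pattern using a toy token-wise model in place of BigBirdPegasusForConditionalGeneration; it only illustrates the padded-vs-unpadded comparison idea and is not code from the repository.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the real model: per-token logits, no cross-token mixing.
toy_model = nn.Sequential(nn.Embedding(32, 8), nn.Linear(8, 32)).eval()

chunk_length = 4
sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length  # real tokens, then padding
sample_without_padding = [4, 7, 9, 13] * chunk_length                 # same length, no padding
input_ids = torch.tensor([sample_with_padding, sample_without_padding])

with torch.no_grad():
    logits_batched = toy_model(input_ids)                      # (2, seq_len, vocab)
    logits_single = toy_model(input_ids[:1, :-chunk_length])   # first row, padding stripped

# The last three non-padding positions of the first row must agree in both runs.
end = 3 * chunk_length
assert torch.allclose(logits_batched[0, end - 3 : end], logits_single[0, -3:], atol=1e-3)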