mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix batching tests for new models (Mamba and SegGPT) (#29633)
* fix batching tests for new models * Update tests/models/seggpt/test_modeling_seggpt.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
31d01150ad
commit
5ac264d8a8
@ -245,6 +245,60 @@ class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_batching_equivalence(self):
    """Check that running the model on a full batch and on a single row gives equal outputs.

    For every model class, prepares a batched input, slices out a single-row input,
    runs both through the model under ``torch.no_grad()``, and recursively compares
    every output tensor: no ``nan``/``inf`` values, and a max absolute difference of
    at most 1e-03 between the first batched row and the single-row output.
    """

    def recursive_check(batched_object, single_row_object, model_name, key):
        # Walk nested tuples/lists (e.g. hidden_states) down to the tensors.
        if isinstance(batched_object, (list, tuple)):
            # Guard against silent truncation: plain zip() would stop at the
            # shorter sequence and a missing element would go unnoticed.
            self.assertEqual(
                len(batched_object),
                len(single_row_object),
                f"Batched and single row outputs have different lengths in {model_name} for key={key}",
            )
            for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                recursive_check(batched_object_value, single_row_object_value, model_name, key)
        else:
            # Compare only the first row of the batched output against the single-row run.
            batched_row = batched_object[:1]
            self.assertFalse(
                torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
            )
            self.assertFalse(
                torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
            )
            self.assertFalse(
                torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
            )
            self.assertFalse(
                torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
            )
            # Compute the max abs difference once; reused in both the check and the message.
            max_diff = torch.max(torch.abs(batched_row - single_row_object))
            self.assertTrue(
                max_diff <= 1e-03,
                msg=(
                    f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                    f"Difference={max_diff}."
                ),
            )

    config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        config.output_hidden_states = True

        model_name = model_class.__name__
        batched_input_prepared = self._prepare_for_class(batched_input, model_class)
        model = model_class(config).to(torch_device).eval()

        batch_size = self.model_tester.batch_size
        single_row_input = {}
        for key, value in batched_input_prepared.items():
            # Only slice tensors whose leading dim is a multiple of batch_size;
            # other inputs are dropped from the single-row run. Some inputs stack
            # several tensors along the batch dim, so the per-batch row count
            # (shape[0] // batch_size) can be larger than 1.
            if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
                rows_per_batch = value.shape[0] // batch_size
                single_row_input[key] = value[:rows_per_batch]

        with torch.no_grad():
            model_batched_output = model(**batched_input_prepared)
            model_row_output = model(**single_row_input)

        for key in model_batched_output:
            # the first hidden state in SegGPT has weird hack of adding first half of batch with second half
            if key == "hidden_states":
                model_batched_output[key] = model_batched_output[key][1:]
                model_row_output[key] = model_row_output[key][1:]
            recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
@ -720,8 +720,8 @@ class ModelTesterMixin:
|
||||
batched_object.values(), single_row_object.values()
|
||||
):
|
||||
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||
# do not compare returned loss (0-dim tensor) or codebook ids (int)
|
||||
elif batched_object is None or isinstance(batched_object, int):
|
||||
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
|
||||
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
|
||||
return
|
||||
elif batched_object.dim() == 0:
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user