mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Clean up semantic segmentation tests (#16801)
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
989a15d173
commit
494c2a8c4d
@ -244,13 +244,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||
continue
|
||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
[self.model_tester.batch_size, height, width], device=torch_device
|
||||
).long()
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
@ -316,13 +316,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class in get_values(MODEL_MAPPING):
|
||||
continue
|
||||
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||
if model_class.__name__ == "SegformerForSemanticSegmentation":
|
||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
[self.model_tester.batch_size, height, width], device=torch_device
|
||||
).long()
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
Loading…
Reference in New Issue
Block a user