mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix tests (#14289)
This commit is contained in:
parent
24b30d4d2f
commit
34307bb358
@ -232,7 +232,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||||
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
[self.model_tester.batch_size, height, width], device=torch_device
|
||||||
|
).long()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
@ -259,7 +261,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||||
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
elif model_class.__name__ == "BeitForSemanticSegmentation":
|
||||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||||
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
[self.model_tester.batch_size, height, width], device=torch_device
|
||||||
|
).long()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -318,7 +318,9 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
|
||||||
if model_class.__name__ == "SegformerForSemanticSegmentation":
|
if model_class.__name__ == "SegformerForSemanticSegmentation":
|
||||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||||
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
[self.model_tester.batch_size, height, width], device=torch_device
|
||||||
|
).long()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
|
Loading…
Reference in New Issue
Block a user