Fix device issue in UperNetModelIntegrationTest (#21192)

fix device

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-01-19 14:26:14 +01:00 committed by GitHub
parent 35920c9715
commit 5761ceb35a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -272,7 +272,7 @@ def prepare_img():
class UperNetModelIntegrationTest(unittest.TestCase):
def test_inference_swin_backbone(self):
processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-swin-tiny")
model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-tiny")
model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-tiny").to(torch_device)
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)
@ -290,7 +290,7 @@ class UperNetModelIntegrationTest(unittest.TestCase):
def test_inference_convnext_backbone(self):
processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny")
model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny")
model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny").to(torch_device)
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device)