fix: repair depth estimation multiprocessing (#33759)

* fix: repair depth estimation multiprocessing

* test: add test for multiprocess depth estimation
Author: Nicola De Angeli, 2024-10-01 18:59:59 +02:00 (committed by GitHub)
Commit: 0256520794, parent: f205da9660
2 changed files with 24 additions and 2 deletions
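Why the change is needed (a restatement of the commit's intent, not text from the original patch): when a pipeline runs with num_workers > 0, preprocess executes inside DataLoader worker processes, so an attribute written to self there (the old self.image_size) never reaches postprocess in the parent process. The patch instead carries the per-image target size inside the input/output dictionaries. Below is a minimal, illustrative-only sketch of that pattern; the class and stub values are made up and are not part of the transformers API.

# Illustrative sketch: per-sample metadata travels with the sample itself,
# never via attributes on the pipeline instance, because preprocess may run
# in a separate worker process when num_workers > 0.
class TinyDepthPipeline:
    def preprocess(self, image_size):
        # Instead of `self.image_size = image_size` (lost across processes),
        # attach the value to the sample's inputs, flipping (W, H) -> (H, W).
        return {"target_size": image_size[::-1]}

    def _forward(self, model_inputs):
        # Pop the metadata before the (stubbed) model call, then reattach it.
        target_size = model_inputs.pop("target_size")
        model_outputs = {"predicted_depth": None}  # stand-in for a real forward pass
        model_outputs["target_size"] = target_size
        return model_outputs

    def postprocess(self, model_outputs):
        # The size is still available even if preprocess ran in another process.
        return model_outputs["target_size"]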

src/transformers/pipelines/depth_estimation.py

@@ -89,20 +89,22 @@ class DepthEstimationPipeline(Pipeline):
     def preprocess(self, image, timeout=None):
         image = load_image(image, timeout)
-        self.image_size = image.size
         model_inputs = self.image_processor(images=image, return_tensors=self.framework)
         if self.framework == "pt":
             model_inputs = model_inputs.to(self.torch_dtype)
+        model_inputs["target_size"] = image.size[::-1]
         return model_inputs

     def _forward(self, model_inputs):
+        target_size = model_inputs.pop("target_size")
         model_outputs = self.model(**model_inputs)
+        model_outputs["target_size"] = target_size
         return model_outputs

     def postprocess(self, model_outputs):
         predicted_depth = model_outputs.predicted_depth
         prediction = torch.nn.functional.interpolate(
-            predicted_depth.unsqueeze(1), size=self.image_size[::-1], mode="bicubic", align_corners=False
+            predicted_depth.unsqueeze(1), size=model_outputs["target_size"], mode="bicubic", align_corners=False
         )
         output = prediction.squeeze().cpu().numpy()
         formatted = (output * 255 / np.max(output)).astype("uint8")
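A side note on the size argument (my reading of the change, not part of the patch): PIL reports image.size as (width, height), while torch.nn.functional.interpolate expects size=(height, width) for a 4D input, hence the [::-1] flip in both the old self.image_size[::-1] and the new target_size. A small standalone check with made-up dimensions:

from PIL import Image
import torch

img = Image.new("RGB", (640, 480))   # PIL: size == (width, height) == (640, 480)
target_size = img.size[::-1]         # interpolate wants (height, width) == (480, 640)

depth = torch.rand(1, 1, 32, 32)     # fake low-resolution depth map, shape (N, C, H, W)
resized = torch.nn.functional.interpolate(
    depth, size=target_size, mode="bicubic", align_corners=False
)
print(resized.shape)                 # torch.Size([1, 1, 480, 640])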

tests/pipelines/test_pipelines_depth_estimation.py

@@ -116,3 +116,23 @@ class DepthEstimationPipelineTests(unittest.TestCase):
     def test_small_model_pt(self):
         # This is highly irregular to have no small tests.
         self.skipTest(reason="There is not hf-internal-testing tiny model for either GLPN nor DPT")
+
+    @require_torch
+    def test_multiprocess(self):
+        depth_estimator = pipeline(
+            model="hf-internal-testing/tiny-random-DepthAnythingForDepthEstimation",
+            num_workers=2,
+        )
+        outputs = depth_estimator(
+            [
+                "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "./tests/fixtures/tests_samples/COCO/000000039769.png",
+            ]
+        )
+        self.assertEqual(
+            [
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+            ],
+            outputs,
+        )