mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
fix: repair depth estimation multiprocessing (#33759)
* fix: repair depth estimation multiprocessing * test: add test for multiprocess depth estimation
This commit is contained in:
parent
f205da9660
commit
0256520794
@ -89,20 +89,22 @@ class DepthEstimationPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, image, timeout=None):
|
||||
image = load_image(image, timeout)
|
||||
self.image_size = image.size
|
||||
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
||||
if self.framework == "pt":
|
||||
model_inputs = model_inputs.to(self.torch_dtype)
|
||||
model_inputs["target_size"] = image.size[::-1]
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
target_size = model_inputs.pop("target_size")
|
||||
model_outputs = self.model(**model_inputs)
|
||||
model_outputs["target_size"] = target_size
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
predicted_depth = model_outputs.predicted_depth
|
||||
prediction = torch.nn.functional.interpolate(
|
||||
predicted_depth.unsqueeze(1), size=self.image_size[::-1], mode="bicubic", align_corners=False
|
||||
predicted_depth.unsqueeze(1), size=model_outputs["target_size"], mode="bicubic", align_corners=False
|
||||
)
|
||||
output = prediction.squeeze().cpu().numpy()
|
||||
formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
|
@ -116,3 +116,23 @@ class DepthEstimationPipelineTests(unittest.TestCase):
|
||||
def test_small_model_pt(self):
|
||||
# This is highly irregular to have no small tests.
|
||||
self.skipTest(reason="There is not hf-internal-testing tiny model for either GLPN nor DPT")
|
||||
|
||||
@require_torch
|
||||
def test_multiprocess(self):
|
||||
depth_estimator = pipeline(
|
||||
model="hf-internal-testing/tiny-random-DepthAnythingForDepthEstimation",
|
||||
num_workers=2,
|
||||
)
|
||||
outputs = depth_estimator(
|
||||
[
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
[
|
||||
{"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
|
||||
{"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
|
||||
],
|
||||
outputs,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user