mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix(DPT,Depth-Anything) torch.export
(#34103)
* Fix torch.export issue in dpt based models Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Simplify the if statements Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Move activation definitions of zoe_depth to init() Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Add test_export for dpt and zoedepth Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * add depth anything Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Remove zoedepth non-automated zoedepth changes and zoedepth test Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * [run_slow] dpt, depth_anything, zoedepth Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> --------- Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>
This commit is contained in:
parent
9d16441e4f
commit
8cadf76e1c
@ -224,16 +224,16 @@ class DepthAnythingFeatureFusionStage(nn.Module):
|
||||
hidden_states = hidden_states[::-1]
|
||||
|
||||
fused_hidden_states = []
|
||||
# first layer only uses the last hidden_state
|
||||
size = hidden_states[1].shape[2:]
|
||||
fused_hidden_state = self.layers[0](hidden_states[0], size=size)
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
fused_hidden_state = None
|
||||
|
||||
# looping from the last layer to the second
|
||||
for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:])):
|
||||
size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None
|
||||
for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
|
||||
size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
|
||||
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
|
||||
if fused_hidden_state is None:
|
||||
# first layer only uses the last hidden_state
|
||||
fused_hidden_state = layer(hidden_state, size=size)
|
||||
else:
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
|
||||
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
|
||||
|
@ -689,12 +689,13 @@ class DPTFeatureFusionStage(nn.Module):
|
||||
hidden_states = hidden_states[::-1]
|
||||
|
||||
fused_hidden_states = []
|
||||
# first layer only uses the last hidden_state
|
||||
fused_hidden_state = self.layers[0](hidden_states[0])
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
# looping from the last layer to the second
|
||||
for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
||||
fused_hidden_state = None
|
||||
for hidden_state, layer in zip(hidden_states, self.layers):
|
||||
if fused_hidden_state is None:
|
||||
# first layer only uses the last hidden_state
|
||||
fused_hidden_state = layer(hidden_state)
|
||||
else:
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
|
||||
return fused_hidden_states
|
||||
|
@ -185,12 +185,13 @@ class ZoeDepthFeatureFusionStage(nn.Module):
|
||||
hidden_states = hidden_states[::-1]
|
||||
|
||||
fused_hidden_states = []
|
||||
# first layer only uses the last hidden_state
|
||||
fused_hidden_state = self.layers[0](hidden_states[0])
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
# looping from the last layer to the second
|
||||
for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
||||
fused_hidden_state = None
|
||||
for hidden_state, layer in zip(hidden_states, self.layers):
|
||||
if fused_hidden_state is None:
|
||||
# first layer only uses the last hidden_state
|
||||
fused_hidden_state = layer(hidden_state)
|
||||
else:
|
||||
fused_hidden_state = layer(fused_hidden_state, hidden_state)
|
||||
fused_hidden_states.append(fused_hidden_state)
|
||||
|
||||
return fused_hidden_states
|
||||
|
@ -18,6 +18,7 @@ import unittest
|
||||
|
||||
from transformers import DepthAnythingConfig, Dinov2Config
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -290,3 +291,30 @@ class DepthAnythingModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
def test_export(self):
|
||||
for strict in [True, False]:
|
||||
with self.subTest(strict=strict):
|
||||
if not is_torch_greater_or_equal_than_2_4:
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
model = (
|
||||
DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
args=(inputs["pixel_values"],),
|
||||
strict=strict,
|
||||
)
|
||||
with torch.no_grad():
|
||||
eager_outputs = model(**inputs)
|
||||
exported_outputs = exported_program.module().forward(inputs["pixel_values"])
|
||||
self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
|
||||
)
|
||||
|
@ -18,6 +18,7 @@ import unittest
|
||||
|
||||
from transformers import DPTConfig
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@ -410,3 +411,24 @@ class DPTModelIntegrationTest(unittest.TestCase):
|
||||
).squeeze()
|
||||
self.assertTrue(output_enlarged.shape == expected_shape)
|
||||
self.assertTrue(torch.allclose(predicted_depth_l, output_enlarged, rtol=1e-3))
|
||||
|
||||
def test_export(self):
|
||||
for strict in [True, False]:
|
||||
with self.subTest(strict=strict):
|
||||
if not is_torch_greater_or_equal_than_2_4:
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device).eval()
|
||||
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
args=(inputs["pixel_values"],),
|
||||
strict=strict,
|
||||
)
|
||||
with torch.no_grad():
|
||||
eager_outputs = model(**inputs)
|
||||
exported_outputs = exported_program.module().forward(inputs["pixel_values"])
|
||||
self.assertEqual(eager_outputs.logits.shape, exported_outputs.logits.shape)
|
||||
self.assertTrue(torch.allclose(eager_outputs.logits, exported_outputs.logits, atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user