mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
fix

commit fce61367b5
parent e23c242848
@@ -21,6 +21,7 @@ import numpy as np
 from tests.test_modeling_common import floats_tensor
 from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import (
+    Expectations,
     require_timm,
     require_torch,
     require_torch_accelerator,
@@ -526,13 +527,21 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
             rtol=TOLERANCE,
         )
 
-        expected_slice_hidden_state = torch.tensor(
-            [
-                [-0.8422, -0.8435, -0.9717],
-                [-1.0145, -0.5564, -0.4195],
-                [-1.0040, -0.4486, -0.1962],
-            ]
-        ).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [-0.8422, -0.8434, -0.9718],
+                    [-1.0144, -0.5565, -0.4195],
+                    [-1.0038, -0.4484, -0.1961],
+                ],
+                ("cuda", 8): [
+                    [-0.8422, -0.8435, -0.9717],
+                    [-1.0145, -0.5564, -0.4195],
+                    [-1.0040, -0.4486, -0.1962],
+                ],
+            }
+        )
+        expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.allclose(
             outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3],
             expected_slice_hidden_state,
@@ -540,13 +549,21 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
             rtol=TOLERANCE,
         )
 
-        expected_slice_hidden_state = torch.tensor(
-            [
-                [0.2853, -0.0162, 0.9736],
-                [0.6256, 0.1856, 0.8530],
-                [-0.0679, -0.4118, 1.8416],
-            ]
-        ).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [0.2852, -0.0159, 0.9735],
+                    [0.6254, 0.1858, 0.8529],
+                    [-0.0680, -0.4116, 1.8413],
+                ],
+                ("cuda", 8): [
+                    [0.2853, -0.0162, 0.9736],
+                    [0.6256, 0.1856, 0.8530],
+                    [-0.0679, -0.4118, 1.8416],
+                ],
+            }
+        )
+        expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.allclose(
             outputs.transformer_decoder_last_hidden_state[0, :3, :3],
             expected_slice_hidden_state,
@@ -577,25 +594,42 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
             masks_queries_logits.shape,
             (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
         )
-        expected_slice = [
-            [-1.3737, -1.7727, -1.9367],
-            [-1.5979, -1.9871, -2.1527],
-            [-1.5797, -1.9271, -2.0941],
-        ]
-        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [-1.3737124, -1.7724937, -1.9364233],
+                    [-1.5977281, -1.9867939, -2.1523695],
+                    [-1.5795398, -1.9269832, -2.093942],
+                ],
+                ("cuda", 8): [
+                    [-1.3737, -1.7727, -1.9367],
+                    [-1.5979, -1.9871, -2.1527],
+                    [-1.5797, -1.9271, -2.0941],
+                ],
+            }
+        )
+        expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
         # class_queries_logits
         class_queries_logits = outputs.class_queries_logits
         self.assertEqual(
             class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
         )
-        expected_slice = torch.tensor(
-            [
-                [1.6507e00, -5.2568e00, -3.3520e00],
-                [3.5767e-02, -5.9023e00, -2.9313e00],
-                [-6.2712e-04, -7.7627e00, -5.1268e00],
-            ]
-        ).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [1.6512e00, -5.2572e00, -3.3519e00],
+                    [3.6169e-02, -5.9025e00, -2.9313e00],
+                    [1.0766e-04, -7.7630e00, -5.1263e00],
+                ],
+                ("cuda", 8): [
+                    [1.6507e00, -5.2568e00, -3.3520e00],
+                    [3.5767e-02, -5.9023e00, -2.9313e00],
+                    [-6.2712e-04, -7.7627e00, -5.1268e00],
+                ],
+            }
+        )
+        expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.testing.assert_close(
             outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE
         )
@@ -623,25 +657,42 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
             masks_queries_logits.shape,
             (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
         )
-        expected_slice = [
-            [-0.9000, -2.6283, -4.5964],
-            [-3.4123, -5.7789, -8.7919],
-            [-4.9132, -7.6444, -10.7557],
-        ]
-        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [-0.9046, -2.6366, -4.6062],
+                    [-3.4179, -5.7890, -8.8057],
+                    [-4.9179, -7.6560, -10.7711],
+                ],
+                ("cuda", 8): [
+                    [-0.9000, -2.6283, -4.5964],
+                    [-3.4123, -5.7789, -8.7919],
+                    [-4.9132, -7.6444, -10.7557],
+                ],
+            }
+        )
+        expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
         # class_queries_logits
         class_queries_logits = outputs.class_queries_logits
         self.assertEqual(
             class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
         )
-        expected_slice = torch.tensor(
-            [
-                [4.7177, -3.2586, -2.8853],
-                [6.6845, -2.9186, -1.2491],
-                [7.2443, -2.2760, -2.1858],
-            ]
-        ).to(torch_device)
+        expectations = Expectations(
+            {
+                (None, None): [
+                    [-0.9000, -2.6283, -4.5964],
+                    [-3.4123, -5.7789, -8.7919],
+                    [-4.9132, -7.6444, -10.7557],
+                ],
+                ("cuda", 8): [
+                    [4.7177, -3.2586, -2.8853],
+                    [6.6845, -2.9186, -1.2491],
+                    [7.2443, -2.2760, -2.1858],
+                ],
+            }
+        )
+        expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
         torch.testing.assert_close(
             outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE
        )
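
The change above replaces single hardcoded expected slices with device-keyed expectations: values are stored under (device_type, compute-capability major) keys, (None, None) serves as the generic fallback, and get_expectation() returns the entry matching the machine running the test (here, dedicated values for CUDA devices with compute capability 8, such as A100). The snippet below is a minimal sketch of that lookup pattern, not the actual Expectations implementation in transformers.testing_utils; the class name DeviceExpectations, its constructor, and the explicit device arguments to get_expectation are illustrative assumptions.

# Illustrative stand-in only; the real Expectations helper in
# transformers.testing_utils may resolve device keys differently.
from typing import Optional


class DeviceExpectations:
    """Maps (device_type, major_version) keys to expected values.

    (None, None) is the generic fallback used when no device-specific
    entry matches the machine running the test.
    """

    def __init__(self, data: dict):
        self.data = data

    def get_expectation(self, device_type: Optional[str] = None, major: Optional[int] = None):
        # Prefer an exact (device, major) match, then (device, None), then the fallback.
        for key in ((device_type, major), (device_type, None), (None, None)):
            if key in self.data:
                return self.data[key]
        raise KeyError(f"no expectation found for device {device_type!r}")


# Usage mirroring the test change above:
expectations = DeviceExpectations(
    {
        (None, None): [[-0.8422, -0.8434, -0.9718]],
        ("cuda", 8): [[-0.8422, -0.8435, -0.9717]],
    }
)
print(expectations.get_expectation("cuda", 8))  # compute-capability-8 reference values
print(expectations.get_expectation())           # generic fallback

Unlike this sketch, the helper called in the tests takes no arguments in get_expectation() and detects the running device on its own; the explicit parameters here only keep the example self-contained.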