Fix dtype casting in swinv2 and swinv2sr to allow non-FP32 inference (#31589)
* Fix dtype casting in modeling_swin2sr to allow non-FP32 inference
* Fix formatting
* Fix for swinv2 too
* Update src/transformers/models/swin2sr/modeling_swin2sr.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/swinv2/modeling_swinv2.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Add FP16 tests for swin2sr and swinv2
* [run_slow] swin2sr, swinv2
* [run_slow] swin2sr, swinv2

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a3fb96a42a
commit 1f9f57ab4c
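Why this cast is needed: when a checkpoint is loaded with torch_dtype=torch.float16, the weights of continuous_position_bias_mlp are created in half precision, but relative_coords_table is built with explicit FP32 tensor math inside __init__, so passing the buffer through the MLP fails with a dtype mismatch. Below is a minimal sketch of the failure mode and of the fix pattern, using plain PyTorch stand-ins rather than the actual transformers modules:

import math
import torch
from torch import nn

param_dtype = torch.float16  # stands in for from_pretrained(..., torch_dtype=torch.float16)

# An MLP whose weights are created directly in fp16, like the loaded checkpoint.
mlp = nn.Sequential(
    nn.Linear(2, 512, dtype=param_dtype),
    nn.ReLU(),
    nn.Linear(512, 1, dtype=param_dtype),
)

# The coordinate table is built with fp32 math, regardless of param_dtype.
table = torch.randn(1, 7, 7, 2)  # defaults to torch.float32
table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / math.log2(8)

# mlp(table)  # would raise a RuntimeError: fp32 input vs. fp16 weights

# The fix applied by this commit: cast the buffer to the MLP's weight dtype.
table = table.to(next(mlp.parameters()).dtype)
out = mlp(table)
print(out.dtype)  # torch.float16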
src/transformers/models/swin2sr/modeling_swin2sr.py
@@ -301,6 +301,8 @@ class Swin2SRSelfAttention(nn.Module):
         relative_coords_table = (
             torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
         )
+        # set to same dtype as mlp weight
+        relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
         self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
src/transformers/models/swinv2/modeling_swinv2.py
@@ -492,6 +492,8 @@ class Swinv2SelfAttention(nn.Module):
         relative_coords_table = (
             torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
         )
+        # set to same dtype as mlp weight
+        relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
         self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
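A quick sanity check for the fix (a hedged sketch: the attribute path below is assumed from the usual Swin-family module layout, encoder.layers[i].blocks[j].attention.self, and is not taken from this diff):

import torch
from transformers import Swinv2Model

model = Swinv2Model.from_pretrained(
    "microsoft/swinv2-tiny-patch4-window8-256", torch_dtype=torch.float16
)
# Attribute path assumed from the Swin-family layout; adjust if it differs.
self_attention = model.encoder.layers[0].blocks[0].attention.self
print(self_attention.relative_coords_table.dtype)  # expected after this fix: torch.float16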
tests/models/swin2sr/test_modeling_swin2sr.py
@@ -333,3 +333,24 @@ class Swin2SRModelIntegrationTest(unittest.TestCase):
             [[0.5458, 0.5546, 0.5638], [0.5526, 0.5565, 0.5651], [0.5396, 0.5426, 0.5621]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-4))
+
+    def test_inference_fp16(self):
+        processor = Swin2SRImageProcessor()
+        model = Swin2SRForImageSuperResolution.from_pretrained(
+            "caidas/swin2SR-classical-sr-x2-64", torch_dtype=torch.float16
+        ).to(torch_device)
+
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = processor(images=image, return_tensors="pt").to(model.dtype).to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # verify the logits
+        expected_shape = torch.Size([1, 3, 976, 1296])
+        self.assertEqual(outputs.reconstruction.shape, expected_shape)
+        expected_slice = torch.tensor(
+            [[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], dtype=model.dtype
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-4))
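For completeness, a hedged sketch of turning the FP16 reconstruction from the test above into a saved image; the clamp-and-scale post-processing follows common Swin2SR usage, and the output file name is illustrative:

import numpy as np
from PIL import Image

# `outputs` as produced in the test above; cast to fp32 before numpy conversion.
output = outputs.reconstruction.squeeze().float().clamp(0, 1).cpu().numpy()
output = np.moveaxis(output, 0, -1)  # CHW -> HWC
Image.fromarray((output * 255.0).round().astype(np.uint8)).save("upscaled.png")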
tests/models/swinv2/test_modeling_swinv2.py
@@ -487,6 +487,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
         expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
 
+    @slow
+    def test_inference_fp16(self):
+        model = Swinv2ForImageClassification.from_pretrained(
+            "microsoft/swinv2-tiny-patch4-window8-256", torch_dtype=torch.float16
+        ).to(torch_device)
+        image_processor = self.default_image_processor
+
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        inputs = image_processor(images=image, return_tensors="pt").to(model.dtype).to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # verify the logits
+        expected_shape = torch.Size((1, 1000))
+        self.assertEqual(outputs.logits.shape, expected_shape)
+        expected_slice = torch.tensor([-0.3938, -0.4290, 0.0020], dtype=model.dtype).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
     @slow
     def test_inference_interpolate_pos_encoding(self):
         # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
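On the classification side, a similarly hedged sketch that reuses model and outputs from the Swinv2 test above and decodes the FP16 logits through the standard config.id2label mapping:

# `model` and `outputs` as produced in the Swinv2 FP16 test above.
predicted_class_id = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_class_id])  # human-readable ImageNet label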