diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 40b01d34a8f..711e688e564 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -16,6 +16,7 @@ import collections.abc import math +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -196,12 +197,16 @@ class BeitEmbeddings(nn.Module): self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, + interpolate_pos_encoding: Optional[bool] = None, ) -> torch.Tensor: + if self.position_embeddings is not None and interpolate_pos_encoding is not None: + warnings.warn( + "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always " + "interpolated to the input image size. The argument will be removed in transformers v4.51.0." + ) + _, _, height, width = pixel_values.shape - embeddings, (patch_height, patch_width) = self.patch_embeddings( - pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None - ) + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: @@ -211,14 +216,11 @@ class BeitEmbeddings(nn.Module): embeddings = embeddings * (1 - w) + mask_tokens * w cls_tokens = self.cls_token.expand(batch_size, -1, -1) - if self.position_embeddings is not None: - if interpolate_pos_encoding: - cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) - else: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] - embeddings = torch.cat((cls_tokens, embeddings), dim=1) + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + embeddings = self.dropout(embeddings) return embeddings, (patch_height, patch_width) @@ -248,11 +250,7 @@ class BeitPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward( - self, - pixel_values: torch.Tensor, - position_embedding: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -261,17 +259,6 @@ class BeitPatchEmbeddings(nn.Module): embeddings = self.projection(pixel_values) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] - - if position_embedding is not None: - # interpolate the position embedding to the corresponding size - position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute( - 0, 3, 1, 2 - ) - position_embedding = nn.functional.interpolate( - position_embedding, size=(patch_height, patch_width), mode="bicubic" - ) - embeddings = embeddings + position_embedding - embeddings = embeddings.flatten(2).transpose(1, 2) return embeddings, (patch_height, patch_width) @@ -887,9 +874,7 @@ class BeitModel(BeitPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, _ = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding - ) + embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) resolution = pixel_values.shape[2:] encoder_outputs = self.encoder( diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 8e4d9c0bb26..f6dba8235df 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -16,6 +16,7 @@ import collections.abc import math +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -195,12 +196,16 @@ class Data2VecVisionEmbeddings(nn.Module): self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, + interpolate_pos_encoding: Optional[bool] = None, ) -> torch.Tensor: + if self.position_embeddings is not None and interpolate_pos_encoding is not None: + warnings.warn( + "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always " + "interpolated to the input image size. The argument will be removed in transformers v4.51.0." + ) + _, _, height, width = pixel_values.shape - embeddings, (patch_height, patch_width) = self.patch_embeddings( - pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None - ) + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: @@ -210,14 +215,11 @@ class Data2VecVisionEmbeddings(nn.Module): embeddings = embeddings * (1 - w) + mask_tokens * w cls_tokens = self.cls_token.expand(batch_size, -1, -1) - if self.position_embeddings is not None: - if interpolate_pos_encoding: - cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width) - else: - cls_tokens = cls_tokens + self.position_embeddings[:, :1, :] - embeddings = torch.cat((cls_tokens, embeddings), dim=1) + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + embeddings = self.dropout(embeddings) return embeddings, (patch_height, patch_width) @@ -248,11 +250,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward( - self, - pixel_values: torch.Tensor, - position_embedding: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( @@ -261,17 +259,6 @@ class Data2VecVisionPatchEmbeddings(nn.Module): embeddings = self.projection(pixel_values) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] - - if position_embedding is not None: - # interpolate the position embedding to the corresponding size - position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute( - 0, 3, 1, 2 - ) - position_embedding = nn.functional.interpolate( - position_embedding, size=(patch_height, patch_width), mode="bicubic" - ) - embeddings = embeddings + position_embedding - embeddings = embeddings.flatten(2).transpose(1, 2) return embeddings, (patch_height, patch_width) @@ -902,9 +889,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, _ = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding - ) + embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) resolution = pixel_values.shape[2:] encoder_outputs = self.encoder( diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 89245c7009b..c455c9eebb1 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -774,7 +774,9 @@ class BeitModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(pixel_values, interpolate_pos_encoding=True) - expected_shape = torch.Size((1, 1801, 768)) + # num_cls_tokens + (height / patch_size) * (width / patch_size) + # 1 + (480 / 16) * (480 / 16) = 1 + 30 * 30 = 901 + expected_shape = torch.Size((1, 901, 768)) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 0a9d1fd1812..f297d3a3c6d 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -565,17 +565,12 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase): inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480}) pixel_values = inputs.pixel_values.to(torch_device) - # with interpolate_pos_encoding being False an exception should be raised with higher resolution - # images than what the model supports. - self.assertFalse(processor.do_center_crop) - with torch.no_grad(): - with self.assertRaises(ValueError, msg="doesn't match model"): - model(pixel_values, interpolate_pos_encoding=False) - # with interpolate_pos_encoding being True the model should process the higher resolution image # successfully and produce the expected output. with torch.no_grad(): outputs = model(pixel_values, interpolate_pos_encoding=True) - expected_shape = torch.Size((1, 1801, 768)) + # num_cls_tokens + (height / patch_size) * (width / patch_size) + # 1 + (480 / 16) * (480 / 16) = 901 + expected_shape = torch.Size((1, 901, 768)) self.assertEqual(outputs.last_hidden_state.shape, expected_shape) diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index bfeded558b6..e92cc6ddc28 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -684,7 +684,12 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase): self.assertTrue(results["segmentation"].shape, expected_shape) torch.testing.assert_close(results["segmentation"][:3, :3], expected_slice_segmentation, rtol=1e-4, atol=1e-4) self.assertTrue(len(results["segments_info"]), expected_number_of_segments) - self.assertDictEqual(results["segments_info"][0], expected_first_segment) + + predicted_first_segment = results["segments_info"][0] + self.assertEqual(predicted_first_segment["id"], expected_first_segment["id"]) + self.assertEqual(predicted_first_segment["label_id"], expected_first_segment["label_id"]) + self.assertEqual(predicted_first_segment["was_fused"], expected_first_segment["was_fused"]) + self.assertAlmostEqual(predicted_first_segment["score"], expected_first_segment["score"], places=3) @require_vision diff --git a/tests/models/dinov2/test_modeling_dinov2.py b/tests/models/dinov2/test_modeling_dinov2.py index 5cbcbe77d90..3e52ad49af3 100644 --- a/tests/models/dinov2/test_modeling_dinov2.py +++ b/tests/models/dinov2/test_modeling_dinov2.py @@ -329,10 +329,10 @@ class Dinov2ModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.last_hidden_state.shape, expected_shape) expected_slice = torch.tensor( - [[-2.1747, -0.4729, 1.0936], [-3.2780, -0.8269, -0.9210], [-2.9129, 1.1284, -0.7306]], + [[-2.2005, -0.4495, 1.0964], [-3.3959, -0.8942, -1.0315], [-2.9355, 1.1564, -0.7656]], device=torch_device, ) - torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3) @require_torch diff --git a/tests/models/textnet/test_modeling_textnet.py b/tests/models/textnet/test_modeling_textnet.py index 5d560f919b1..1c88c95cc8b 100644 --- a/tests/models/textnet/test_modeling_textnet.py +++ b/tests/models/textnet/test_modeling_textnet.py @@ -328,14 +328,18 @@ class TextNetModelIntegrationTest(unittest.TestCase): with torch.no_grad(): output = model(**inputs) - # verify logits - self.assertEqual(output.logits.shape, torch.Size([1, 2])) + # verify output + self.assertEqual(output.last_hidden_state.shape, torch.Size([1, 512, 20, 27])) expected_slice_backbone = torch.tensor( - [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000], + [ + [0.0000, 1.7415, 1.2660], + [0.0000, 1.0084, 1.9692], + [0.0000, 1.7464, 1.7892], + ], device=torch_device, ) torch.testing.assert_close( - output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, rtol=1e-3, atol=1e-3 + output.last_hidden_state[0, 12, :3, :3], expected_slice_backbone, rtol=1e-2, atol=1e-2 ) diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 929d7fec959..32812e33280 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -310,10 +310,10 @@ class ViTModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.last_hidden_state.shape, expected_shape) expected_slice = torch.tensor( - [[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]] + [[4.2325, 4.3882, -6.6678], [4.5372, 1.8933, -6.7355], [4.4454, 0.8514, -5.8747]] ).to(torch_device) - torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3) @slow @require_accelerate diff --git a/tests/models/zoedepth/test_modeling_zoedepth.py b/tests/models/zoedepth/test_modeling_zoedepth.py index 782b99af436..342ae269d39 100644 --- a/tests/models/zoedepth/test_modeling_zoedepth.py +++ b/tests/models/zoedepth/test_modeling_zoedepth.py @@ -301,8 +301,8 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase): out_l_reduced = torch.nn.functional.interpolate( out_l.unsqueeze(0).unsqueeze(1), size=img.size[::-1], mode="bicubic", align_corners=False ) - self.assertTrue((np.array(out_l.shape)[::-1] == np.array(img.size) * 2).all()) - torch.testing.assert_close(out, out_l_reduced, rtol=2e-2) + out_l_reduced = out_l_reduced.squeeze(0).squeeze(0) + torch.testing.assert_close(out, out_l_reduced, rtol=2e-2, atol=2e-2) def check_post_processing_test(self, image_processor, images, model, pad_input=True, flip_aug=True): inputs = image_processor(images=images, return_tensors="pt", do_pad=pad_input).to(torch_device) @@ -324,7 +324,7 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase): for img, out, expected_slice in zip(images, outputs, expected_slices): out = out["predicted_depth"] self.assertTrue(img.size == out.shape[::-1]) - torch.testing.assert_close(expected_slice, out[:3, :3], atol=1e-3, rtol=1e-3) + torch.testing.assert_close(expected_slice, out[:3, :3], rtol=1e-3, atol=1e-3) self.check_target_size(image_processor, pad_input, images, outputs, raw_outputs, raw_outputs_flipped)