diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py
index 5146fbb89dc..e538d2b4d40 100644
--- a/src/transformers/models/owlv2/modeling_owlv2.py
+++ b/src/transformers/models/owlv2/modeling_owlv2.py
@@ -1311,6 +1311,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
@@ -1320,6 +1322,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
@@ -1432,8 +1435,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1442,8 +1444,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
@@ -1466,8 +1468,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1476,8 +1477,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
index b8e8a36fec7..a06610a643b 100644
--- a/src/transformers/models/owlvit/modeling_owlvit.py
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -1292,6 +1292,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
         if not feature_map.ndim == 4:
@@ -1300,6 +1302,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
@@ -1394,8 +1397,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1404,8 +1406,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
@@ -1427,8 +1429,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1437,8 +1438,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
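
Reviewer note (not part of the patch): the class-token resize change swaps a numpy round-trip for a shape read directly off a tensor slice. A minimal sketch of why the two are equivalent, using small stand-in dimensions rather than the model's real ones:

import numpy as np
import torch

# Stand-in for image_embeds with shape [batch_size, num_patches + 1, hidden_size].
image_embeds = torch.randn(2, 5, 8)

# Old target shape: drop 1 from the sequence dimension via numpy arithmetic.
old_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))

# New target shape: the shape of a slice that drops the last sequence position.
new_size = image_embeds[:, :-1].shape

assert old_size == tuple(new_size)  # (2, 4, 8) either way

# broadcast_to therefore produces identical tensors with either shape.
assert torch.equal(
    torch.broadcast_to(image_embeds[:, :1, :], old_size),
    torch.broadcast_to(image_embeds[:, :1, :], new_size),
)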
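
Likewise, the precomputed self.sqrt_num_patches matches the int(np.sqrt(...)) it replaces, since the patch grid is square and its sequence length is (image_size // patch_size) ** 2. A quick check in line with the patch's TODO about removing numpy usage, with config values assumed for illustration only:

import numpy as np

# Assumed config values for illustration (OWLv2 base-style 960 / 16).
image_size, patch_size = 960, 16

# Computed once in __init__ under the patch.
sqrt_num_patches = image_size // patch_size  # 60

# Sequence length of the patch tokens, excluding the class token.
num_patches = sqrt_num_patches**2  # 3600

# The old per-forward computation recovered the same value from the
# sequence length, but through a float sqrt and a numpy dependency.
assert int(np.sqrt(num_patches)) == sqrt_num_patches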