Remove numpy usage from owlvit (#29326)

* remove numpy usage from owlvit

* fix init owlv2

* style
fxmarty 2024-02-28 09:38:44 +01:00 committed by GitHub
parent ad00c482c7
commit e715c78c66
2 changed files with 18 additions and 16 deletions

src/transformers/models/owlv2/modeling_owlv2.py

@@ -1311,6 +1311,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
@@ -1320,6 +1322,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
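Note that the hunk above only tags the remaining numpy meshgrid with a TODO rather than converting it. A torch-only equivalent could plausibly look like the sketch below (not part of this commit; indexing="xy" mirrors np.meshgrid's default, and the patch count is an illustrative value):

import torch

num_patches = 24  # illustrative; image_size // patch_size for the checkpoint at hand
coords = torch.arange(1, num_patches + 1, dtype=torch.float32)
xx, yy = torch.meshgrid(coords, coords, indexing="xy")  # "xy" matches np.meshgrid's default
box_coordinates = torch.stack((xx, yy), dim=-1)  # shape (num_patches, num_patches, 2), float32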
@@ -1432,8 +1435,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
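The one-liner above relies on a shape identity: slicing off any one position along dim 1 yields the same torch.Size that the old numpy arithmetic produced. A minimal standalone check (tensor sizes are illustrative; numpy is imported only to reproduce the deleted expression):

import numpy as np
import torch

image_embeds = torch.randn(2, 577, 768)  # (batch, 1 + num_patches, hidden_size)

old_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))  # deleted code path
new_size = image_embeds[:, :-1].shape                                 # replacement
assert old_size == tuple(new_size)  # both (2, 576, 768)

class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)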
@@ -1442,8 +1444,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
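The precomputed self.sqrt_num_patches is exact because the sequence length here is a perfect square: the ViT backbone yields (image_size // patch_size) ** 2 patch tokens once the class token is merged away, so no floating-point sqrt is needed. A quick sanity check, using illustrative config values:

import math

image_size, patch_size = 960, 16                    # illustrative values, e.g. an OWLv2-style config
sqrt_num_patches = image_size // patch_size         # computed once in __init__
num_patches = sqrt_num_patches**2                   # patch tokens after removing the class token
assert sqrt_num_patches == math.isqrt(num_patches)  # replaces int(np.sqrt(num_patches))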
@@ -1466,8 +1468,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1476,8 +1477,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)

src/transformers/models/owlvit/modeling_owlvit.py

@@ -1292,6 +1292,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
         if not feature_map.ndim == 4:
@@ -1300,6 +1302,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
@@ -1394,8 +1397,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1404,8 +1406,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
@@ -1427,8 +1429,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1437,8 +1438,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)