mirror of https://github.com/huggingface/transformers.git
Remove numpy usage from owlvit (#29326)
* remove numpy usage from owlvit
* fix init owlv2
* style
This commit is contained in:
parent ad00c482c7
commit e715c78c66
src/transformers/models/owlv2/modeling_owlv2.py
@@ -1311,6 +1311,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
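A quick check of what the new attribute holds (a minimal sketch; the 960/16 values below mirror a typical OWLv2 base-patch16 config and are illustrative, not taken from this diff):

image_size, patch_size = 960, 16                  # illustrative config values
sqrt_num_patches = image_size // patch_size       # 60 patches per side
num_patches = sqrt_num_patches**2                 # 3600 patch tokens (class token excluded)
assert sqrt_num_patches == int(num_patches**0.5)  # matches the old int(np.sqrt(...)) result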
@@ -1320,6 +1322,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
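The TODO above is left unresolved by this commit; a torch-only construction of the same grid could look like the sketch below (an assumption, not part of this diff; note that np.meshgrid defaults to "xy" indexing, so torch.meshgrid must be told to match):

import torch

num_patches = 60  # illustrative; in the model this is feature_map.shape[1]
xx, yy = torch.meshgrid(
    torch.arange(1, num_patches + 1, dtype=torch.float32),
    torch.arange(1, num_patches + 1, dtype=torch.float32),
    indexing="xy",  # match np.meshgrid's default indexing
)
box_coordinates = torch.stack((xx, yy), dim=-1)  # shape (num_patches, num_patches, 2)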
@@ -1432,8 +1435,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
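The replacement computes the same broadcast target shape without round-tripping through numpy; a minimal equivalence check (tensor sizes are illustrative):

import numpy as np
import torch

image_embeds = torch.randn(2, 3601, 768)  # illustrative: (batch, 1 + num_patches, hidden)
old_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
new_size = image_embeds[:, :-1].shape
assert old_size == tuple(new_size)  # (2, 3600, 768) either way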
@@ -1442,8 +1444,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
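With the side length precomputed in __init__, the reshape no longer needs a float sqrt on every forward pass; a minimal sketch with illustrative sizes:

import torch

sqrt_num_patches = 60  # illustrative; image_size // patch_size in the real config
image_embeds = torch.randn(2, sqrt_num_patches**2, 768)
new_size = (image_embeds.shape[0], sqrt_num_patches, sqrt_num_patches, image_embeds.shape[-1])
image_embeds = image_embeds.reshape(new_size)
assert image_embeds.shape == (2, 60, 60, 768)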
@@ -1466,8 +1468,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1476,8 +1477,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
src/transformers/models/owlvit/modeling_owlvit.py
@@ -1292,6 +1292,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
         self.sigmoid = nn.Sigmoid()
 
+        self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+
     def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
         # Computes normalized xy corner coordinates from feature_map.
         if not feature_map.ndim == 4:
@@ -1300,6 +1302,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         device = feature_map.device
         num_patches = feature_map.shape[1]
 
+        # TODO: Remove numpy usage.
         box_coordinates = np.stack(
             np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
         ).astype(np.float32)
@@ -1394,8 +1397,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1404,8 +1406,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
@@ -1427,8 +1429,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -1437,8 +1438,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         # Resize to [batch_size, num_patches, num_patches, hidden_size]
         new_size = (
             image_embeds.shape[0],
-            int(np.sqrt(image_embeds.shape[1])),
-            int(np.sqrt(image_embeds.shape[1])),
+            self.sqrt_num_patches,
+            self.sqrt_num_patches,
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)