mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
owlvit/2 dynamic input resolution (#34764)
* owlvit/2 dynamic input resolution. * adapt box grid to patch_dim_h patch_dim_w * fix ci * clarify variable naming * clarify variable naming.. * compute box_bias dynamically inside box_predictor * change style part of code * [run-slow] owlvit, owlv2
This commit is contained in:
parent
608e163b52
commit
8f38f58f3d
@ -33,6 +33,7 @@ from ...utils import (
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_owlv2 import Owlv2Config, Owlv2TextConfig, Owlv2VisionConfig
|
||||
|
||||
@ -274,6 +275,7 @@ class Owlv2ImageGuidedObjectDetectionOutput(ModelOutput):
|
||||
class Owlv2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Owlv2VisionConfig):
|
||||
super().__init__()
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
|
||||
@ -291,15 +293,59 @@ class Owlv2VisionEmbeddings(nn.Module):
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Return position embeddings sized for the token sequence in `embeddings`.

    When the number of patch tokens equals the number of stored patch positions and the image is square, the
    stored table is returned as is (this shortcut is never taken while `torch.jit` tracing, so exported graphs
    keep the resampling step). Otherwise the stored patch position embeddings are laid out as a square grid,
    resampled with bicubic interpolation to `(height // patch_size, width // patch_size)` and flattened again;
    the class-token position embedding is prepended unchanged.

    Args:
        embeddings: Token embeddings; axis 1 holds one class token followed by the patch tokens, the last axis
            is the embedding dimension.
        height: Input image height in pixels.
        width: Input image width in pixels.

    Returns:
        Position embeddings with a leading axis of size 1, class token first.
    """
    pos_table = self.position_embedding.weight.unsqueeze(0)
    stored_patch_count = pos_table.shape[1] - 1
    incoming_patch_count = embeddings.shape[1] - 1

    # Tracing is checked first so the shape comparisons are never evaluated on traced values.
    needs_resampling = (
        torch.jit.is_tracing() or incoming_patch_count != stored_patch_count or height != width
    )
    if not needs_resampling:
        return self.position_embedding(self.position_ids)

    embed_dim = embeddings.shape[-1]
    target_size = (height // self.patch_size, width // self.patch_size)

    # The stored patch positions are treated as a square grid.
    grid_side = torch_int(stored_patch_count**0.5)

    cls_pos = pos_table[:, :1]
    # (1, N, dim) -> (1, side, side, dim) -> (1, dim, side, side), the layout `interpolate` expects.
    patch_grid = pos_table[:, 1:].reshape(1, grid_side, grid_side, embed_dim).permute(0, 3, 1, 2)
    patch_grid = nn.functional.interpolate(
        patch_grid,
        size=target_size,
        mode="bicubic",
        align_corners=False,
    )
    # Back to channels-last and flatten the grid into a token axis.
    patch_pos = patch_grid.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((cls_pos, patch_pos), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
    """
    Embed an image batch as one class token followed by the patch tokens, with position embeddings added.

    Position embeddings are added exactly once: either the resampled ones from
    `self.interpolate_pos_encoding` or the stored table looked up with `self.position_ids`.

    Args:
        pixel_values: Image batch, unpacked as `(batch_size, channels, height, width)`.
        interpolate_pos_encoding: If `True`, resample the stored position embeddings to the input resolution
            instead of using them as stored.

    Returns:
        Tensor with the class token at index 0 of axis 1, followed by the flattened patch tokens.
    """
    batch_size, _, height, width = pixel_values.shape

    # Patch projection, then collapse the spatial grid into a token axis: (batch, tokens, dim).
    patch_embeds = self.patch_embedding(pixel_values)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

    # Exactly one of the two branches contributes position information.
    if interpolate_pos_encoding:
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
    else:
        embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
|
||||
|
||||
|
||||
@ -610,6 +656,8 @@ OWLV2_VISION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -635,6 +683,8 @@ OWLV2_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_base_image_embeds (`bool`, *optional*):
|
||||
Whether or not to return the base image embeddings.
|
||||
return_dict (`bool`, *optional*):
|
||||
@ -657,6 +707,8 @@ OWLV2_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
|
||||
`vision_model_last_hidden_state` under returned tensors for more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -673,6 +725,8 @@ OWLV2_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -914,6 +968,7 @@ class Owlv2VisionTransformer(nn.Module):
|
||||
pixel_values: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -929,7 +984,7 @@ class Owlv2VisionTransformer(nn.Module):
|
||||
expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
|
||||
pixel_values = pixel_values.to(expected_input_dtype)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
@ -976,6 +1031,7 @@ class Owlv2VisionModel(Owlv2PreTrainedModel):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -1002,6 +1058,7 @@ class Owlv2VisionModel(Owlv2PreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1084,6 +1141,7 @@ class Owlv2Model(Owlv2PreTrainedModel):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
@ -1115,6 +1173,7 @@ class Owlv2Model(Owlv2PreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1133,6 +1192,7 @@ class Owlv2Model(Owlv2PreTrainedModel):
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_base_image_embeds: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, Owlv2Output]:
|
||||
@ -1165,6 +1225,7 @@ class Owlv2Model(Owlv2PreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1295,21 +1356,23 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
|
||||
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
|
||||
self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
|
||||
self.config = config
|
||||
self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
|
||||
def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
|
||||
def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
|
||||
# Create grid coordinates using torch
|
||||
x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||||
y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||||
x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
|
||||
y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
|
||||
xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
|
||||
|
||||
# Stack the coordinates and divide by num_patches
|
||||
# Stack the coordinates and divide by their respective patch counts
|
||||
box_coordinates = torch.stack((xx, yy), dim=-1)
|
||||
box_coordinates /= num_patches
|
||||
box_coordinates[..., 0] /= num_patches_width
|
||||
box_coordinates[..., 1] /= num_patches_height
|
||||
|
||||
# Flatten (h, w, 2) -> (h*w, 2)
|
||||
box_coordinates = box_coordinates.view(-1, 2)
|
||||
@ -1332,18 +1395,22 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
|
||||
def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
|
||||
def compute_box_bias(
|
||||
self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None
|
||||
) -> torch.Tensor:
|
||||
if feature_map is not None:
|
||||
raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
|
||||
# The box center is biased to its position on the feature grid
|
||||
box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
|
||||
box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width)
|
||||
box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
|
||||
|
||||
# Unnormalize xy
|
||||
box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
|
||||
|
||||
# The box size is biased to the patch size
|
||||
box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
|
||||
box_size = torch.full_like(box_coord_bias, 1.0)
|
||||
box_size[..., 0] /= num_patches_width
|
||||
box_size[..., 1] /= num_patches_height
|
||||
box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
|
||||
|
||||
# Compute box bias
|
||||
@ -1355,6 +1422,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
self,
|
||||
image_feats: torch.FloatTensor,
|
||||
feature_map: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Args:
|
||||
@ -1362,6 +1430,8 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
Features extracted from the image, returned by the `image_text_embedder` method.
|
||||
feature_map:
|
||||
A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
|
||||
interpolate_pos_encoding:
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
Returns:
|
||||
pred_boxes:
|
||||
List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
|
||||
@ -1370,7 +1440,13 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
pred_boxes = self.box_head(image_feats)
|
||||
|
||||
# Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
|
||||
box_bias = self.box_bias.to(feature_map.device)
|
||||
if interpolate_pos_encoding:
|
||||
_, num_patches_height, num_patches_width, _ = feature_map.shape
|
||||
box_bias = self.compute_box_bias(num_patches_height, num_patches_width)
|
||||
else:
|
||||
box_bias = self.box_bias
|
||||
|
||||
box_bias = box_bias.to(feature_map.device)
|
||||
pred_boxes += box_bias
|
||||
pred_boxes = self.sigmoid(pred_boxes)
|
||||
return pred_boxes
|
||||
@ -1403,6 +1479,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
# Encode text and image
|
||||
outputs = self.owlv2(
|
||||
@ -1411,9 +1488,18 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_height = height // self.config.vision_config.patch_size
|
||||
num_patches_width = width // self.config.vision_config.patch_size
|
||||
else:
|
||||
num_patches_height = self.num_patches_height
|
||||
num_patches_width = self.num_patches_width
|
||||
|
||||
# Get image embeddings
|
||||
last_hidden_state = outputs.vision_model_output[0]
|
||||
image_embeds = self.owlv2.vision_model.post_layernorm(last_hidden_state)
|
||||
@ -1425,11 +1511,11 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
image_embeds = image_embeds[:, 1:, :] * class_token_out
|
||||
image_embeds = self.layer_norm(image_embeds)
|
||||
|
||||
# Resize to [batch_size, num_patches, num_patches, hidden_size]
|
||||
# Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
|
||||
new_size = (
|
||||
image_embeds.shape[0],
|
||||
self.sqrt_num_patches,
|
||||
self.sqrt_num_patches,
|
||||
num_patches_height,
|
||||
num_patches_width,
|
||||
image_embeds.shape[-1],
|
||||
)
|
||||
image_embeds = image_embeds.reshape(new_size)
|
||||
@ -1443,9 +1529,20 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
pixel_values: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
# Get Owlv2Model vision embeddings (same as CLIP)
|
||||
vision_outputs = self.owlv2.vision_model(pixel_values=pixel_values, return_dict=True)
|
||||
vision_outputs = self.owlv2.vision_model(
|
||||
pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
|
||||
)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_height = height // self.config.vision_config.patch_size
|
||||
num_patches_width = width // self.config.vision_config.patch_size
|
||||
else:
|
||||
num_patches_height = self.num_patches_height
|
||||
num_patches_width = self.num_patches_width
|
||||
|
||||
# Apply post_layernorm to last_hidden_state, return non-projected output
|
||||
last_hidden_state = vision_outputs[0]
|
||||
@ -1458,11 +1555,11 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
image_embeds = image_embeds[:, 1:, :] * class_token_out
|
||||
image_embeds = self.layer_norm(image_embeds)
|
||||
|
||||
# Resize to [batch_size, num_patches, num_patches, hidden_size]
|
||||
# Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
|
||||
new_size = (
|
||||
image_embeds.shape[0],
|
||||
self.sqrt_num_patches,
|
||||
self.sqrt_num_patches,
|
||||
num_patches_height,
|
||||
num_patches_width,
|
||||
image_embeds.shape[-1],
|
||||
)
|
||||
image_embeds = image_embeds.reshape(new_size)
|
||||
@ -1471,10 +1568,13 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.embed_image_query
|
||||
def embed_image_query(
|
||||
self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
|
||||
self,
|
||||
query_image_features: torch.FloatTensor,
|
||||
query_feature_map: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
_, class_embeds = self.class_predictor(query_image_features)
|
||||
pred_boxes = self.box_predictor(query_image_features, query_feature_map)
|
||||
pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding)
|
||||
pred_boxes_as_corners = center_to_corners_format(pred_boxes)
|
||||
|
||||
# Loop over query images
|
||||
@ -1519,6 +1619,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
query_pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Owlv2ImageGuidedObjectDetectionOutput:
|
||||
r"""
|
||||
@ -1576,26 +1677,33 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
# Compute feature maps for the input and query images
|
||||
query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0]
|
||||
query_feature_map = self.image_embedder(
|
||||
pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)[0]
|
||||
feature_map, vision_outputs = self.image_embedder(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
|
||||
query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape
|
||||
query_image_feats = torch.reshape(
|
||||
query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)
|
||||
)
|
||||
# Get top class embedding and best box index for each query image in batch
|
||||
query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)
|
||||
query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(
|
||||
query_image_feats, query_feature_map, interpolate_pos_encoding
|
||||
)
|
||||
|
||||
# Predict object classes [batch_size, num_patches, num_queries+1]
|
||||
(pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)
|
||||
|
||||
# Predict object boxes
|
||||
target_pred_boxes = self.box_predictor(image_feats, feature_map)
|
||||
target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
@ -1630,6 +1738,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Owlv2ObjectDetectionOutput:
|
||||
r"""
|
||||
@ -1683,14 +1792,15 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
# Text and vision model outputs
|
||||
text_outputs = outputs.text_model_output
|
||||
vision_outputs = outputs.vision_model_output
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
|
||||
|
||||
# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
|
||||
max_text_queries = input_ids.shape[0] // batch_size
|
||||
@ -1707,7 +1817,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
|
||||
objectness_logits = self.objectness_predictor(image_feats)
|
||||
|
||||
# Predict object boxes
|
||||
pred_boxes = self.box_predictor(image_feats, feature_map)
|
||||
pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
|
@ -33,6 +33,7 @@ from ...utils import (
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||||
|
||||
@ -268,6 +269,7 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
|
||||
class OwlViTVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: OwlViTVisionConfig):
|
||||
super().__init__()
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))
|
||||
@ -285,15 +287,55 @@ class OwlViTVisionEmbeddings(nn.Module):
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Return position embeddings sized for the token sequence in `embeddings`.

    When the number of patch tokens equals the number of stored patch positions and the image is square, the
    stored table is returned as is (this shortcut is never taken while `torch.jit` tracing, so exported graphs
    keep the resampling step). Otherwise the stored patch position embeddings are laid out as a square grid,
    resampled with bicubic interpolation to `(height // patch_size, width // patch_size)` and flattened again;
    the class-token position embedding is prepended unchanged.

    Args:
        embeddings: Token embeddings; axis 1 holds one class token followed by the patch tokens, the last axis
            is the embedding dimension.
        height: Input image height in pixels.
        width: Input image width in pixels.

    Returns:
        Position embeddings with a leading axis of size 1, class token first.
    """
    pos_table = self.position_embedding.weight.unsqueeze(0)
    stored_patch_count = pos_table.shape[1] - 1
    incoming_patch_count = embeddings.shape[1] - 1

    # Tracing is checked first so the shape comparisons are never evaluated on traced values.
    needs_resampling = (
        torch.jit.is_tracing() or incoming_patch_count != stored_patch_count or height != width
    )
    if not needs_resampling:
        return self.position_embedding(self.position_ids)

    embed_dim = embeddings.shape[-1]
    target_size = (height // self.patch_size, width // self.patch_size)

    # The stored patch positions are treated as a square grid.
    grid_side = torch_int(stored_patch_count**0.5)

    cls_pos = pos_table[:, :1]
    # (1, N, dim) -> (1, side, side, dim) -> (1, dim, side, side), the layout `interpolate` expects.
    patch_grid = pos_table[:, 1:].reshape(1, grid_side, grid_side, embed_dim).permute(0, 3, 1, 2)
    patch_grid = nn.functional.interpolate(
        patch_grid,
        size=target_size,
        mode="bicubic",
        align_corners=False,
    )
    # Back to channels-last and flatten the grid into a token axis.
    patch_pos = patch_grid.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((cls_pos, patch_pos), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
    """
    Embed an image batch as one class token followed by the patch tokens, with position embeddings added.

    Position embeddings are added exactly once: either the resampled ones from
    `self.interpolate_pos_encoding` or the stored table looked up with `self.position_ids`.

    Args:
        pixel_values: Image batch, unpacked as `(batch_size, channels, height, width)`.
        interpolate_pos_encoding: If `True`, resample the stored position embeddings to the input resolution
            instead of using them as stored.

    Returns:
        Tensor with the class token at index 0 of axis 1, followed by the flattened patch tokens.
    """
    batch_size, _, height, width = pixel_values.shape

    # Patch projection, then collapse the spatial grid into a token axis: (batch, tokens, dim).
    patch_embeds = self.patch_embedding(pixel_values)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

    # Exactly one of the two branches contributes position information.
    if interpolate_pos_encoding:
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
    else:
        embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
|
||||
|
||||
|
||||
@ -601,6 +643,8 @@ OWLVIT_VISION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -626,6 +670,8 @@ OWLVIT_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -646,6 +692,8 @@ OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
|
||||
`vision_model_last_hidden_state` under returned tensors for more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -662,6 +710,8 @@ OWLVIT_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -899,6 +949,7 @@ class OwlViTVisionTransformer(nn.Module):
|
||||
pixel_values: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -914,7 +965,7 @@ class OwlViTVisionTransformer(nn.Module):
|
||||
expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
|
||||
pixel_values = pixel_values.to(expected_input_dtype)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
@ -960,6 +1011,7 @@ class OwlViTVisionModel(OwlViTPreTrainedModel):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -986,6 +1038,7 @@ class OwlViTVisionModel(OwlViTPreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1067,6 +1120,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
@ -1098,6 +1152,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1116,6 +1171,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
||||
return_loss: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_base_image_embeds: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, OwlViTOutput]:
|
||||
@ -1148,6 +1204,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1275,20 +1332,22 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
|
||||
self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
|
||||
self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
|
||||
self.config = config
|
||||
self.num_patches_height = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
self.num_patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
self.box_bias = self.compute_box_bias(self.num_patches_height, self.num_patches_width)
|
||||
|
||||
@staticmethod
|
||||
def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
|
||||
def normalize_grid_corner_coordinates(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
|
||||
# Create grid coordinates using torch
|
||||
x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||||
y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||||
x_coordinates = torch.arange(1, num_patches_width + 1, dtype=torch.float32)
|
||||
y_coordinates = torch.arange(1, num_patches_height + 1, dtype=torch.float32)
|
||||
xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
|
||||
|
||||
# Stack the coordinates and divide by num_patches
|
||||
# Stack the coordinates and divide by their respective patch counts
|
||||
box_coordinates = torch.stack((xx, yy), dim=-1)
|
||||
box_coordinates /= num_patches
|
||||
box_coordinates[..., 0] /= num_patches_width
|
||||
box_coordinates[..., 1] /= num_patches_height
|
||||
|
||||
# Flatten (h, w, 2) -> (h*w, 2)
|
||||
box_coordinates = box_coordinates.view(-1, 2)
|
||||
@ -1296,18 +1355,22 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
return box_coordinates
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
|
||||
def compute_box_bias(
|
||||
self, num_patches_height: int, num_patches_width: int, feature_map: Optional[torch.FloatTensor] = None
|
||||
) -> torch.Tensor:
|
||||
if feature_map is not None:
|
||||
raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
|
||||
# The box center is biased to its position on the feature grid
|
||||
box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
|
||||
box_coordinates = self.normalize_grid_corner_coordinates(num_patches_height, num_patches_width)
|
||||
box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
|
||||
|
||||
# Unnormalize xy
|
||||
box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
|
||||
|
||||
# The box size is biased to the patch size
|
||||
box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
|
||||
box_size = torch.full_like(box_coord_bias, 1.0)
|
||||
box_size[..., 0] /= num_patches_width
|
||||
box_size[..., 1] /= num_patches_height
|
||||
box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
|
||||
|
||||
# Compute box bias
|
||||
@ -1318,6 +1381,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
self,
|
||||
image_feats: torch.FloatTensor,
|
||||
feature_map: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Args:
|
||||
@ -1325,6 +1389,8 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
Features extracted from the image, returned by the `image_text_embedder` method.
|
||||
feature_map:
|
||||
A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
|
||||
interpolate_pos_encoding:
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
Returns:
|
||||
pred_boxes:
|
||||
List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
|
||||
@ -1333,7 +1399,13 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
pred_boxes = self.box_head(image_feats)
|
||||
|
||||
# Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
|
||||
box_bias = self.box_bias.to(feature_map.device)
|
||||
if interpolate_pos_encoding:
|
||||
_, num_patches_height, num_patches_width, _ = feature_map.shape
|
||||
box_bias = self.compute_box_bias(num_patches_height, num_patches_width)
|
||||
else:
|
||||
box_bias = self.box_bias
|
||||
|
||||
box_bias = box_bias.to(feature_map.device)
|
||||
pred_boxes += box_bias
|
||||
pred_boxes = self.sigmoid(pred_boxes)
|
||||
return pred_boxes
|
||||
@ -1364,6 +1436,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
# Encode text and image
|
||||
outputs = self.owlvit(
|
||||
@ -1372,9 +1445,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_height = height // self.config.vision_config.patch_size
|
||||
num_patches_width = width // self.config.vision_config.patch_size
|
||||
else:
|
||||
num_patches_height = self.num_patches_height
|
||||
num_patches_width = self.num_patches_width
|
||||
|
||||
# Get image embeddings
|
||||
last_hidden_state = outputs.vision_model_output[0]
|
||||
image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
|
||||
@ -1386,11 +1468,11 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
image_embeds = image_embeds[:, 1:, :] * class_token_out
|
||||
image_embeds = self.layer_norm(image_embeds)
|
||||
|
||||
# Resize to [batch_size, num_patches, num_patches, hidden_size]
|
||||
# Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
|
||||
new_size = (
|
||||
image_embeds.shape[0],
|
||||
self.sqrt_num_patches,
|
||||
self.sqrt_num_patches,
|
||||
num_patches_height,
|
||||
num_patches_width,
|
||||
image_embeds.shape[-1],
|
||||
)
|
||||
image_embeds = image_embeds.reshape(new_size)
|
||||
@ -1403,9 +1485,20 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
pixel_values: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
# Get OwlViTModel vision embeddings (same as CLIP)
|
||||
vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values, return_dict=True)
|
||||
vision_outputs = self.owlvit.vision_model(
|
||||
pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
|
||||
)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_height = height // self.config.vision_config.patch_size
|
||||
num_patches_width = width // self.config.vision_config.patch_size
|
||||
else:
|
||||
num_patches_height = self.num_patches_height
|
||||
num_patches_width = self.num_patches_width
|
||||
|
||||
# Apply post_layernorm to last_hidden_state, return non-projected output
|
||||
last_hidden_state = vision_outputs[0]
|
||||
@ -1418,11 +1511,11 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
image_embeds = image_embeds[:, 1:, :] * class_token_out
|
||||
image_embeds = self.layer_norm(image_embeds)
|
||||
|
||||
# Resize to [batch_size, num_patches, num_patches, hidden_size]
|
||||
# Resize to [batch_size, num_patches_height, num_patches_width, hidden_size]
|
||||
new_size = (
|
||||
image_embeds.shape[0],
|
||||
self.sqrt_num_patches,
|
||||
self.sqrt_num_patches,
|
||||
num_patches_height,
|
||||
num_patches_width,
|
||||
image_embeds.shape[-1],
|
||||
)
|
||||
image_embeds = image_embeds.reshape(new_size)
|
||||
@ -1430,10 +1523,13 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
return (image_embeds, vision_outputs)
|
||||
|
||||
def embed_image_query(
|
||||
self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
|
||||
self,
|
||||
query_image_features: torch.FloatTensor,
|
||||
query_feature_map: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
_, class_embeds = self.class_predictor(query_image_features)
|
||||
pred_boxes = self.box_predictor(query_image_features, query_feature_map)
|
||||
pred_boxes = self.box_predictor(query_image_features, query_feature_map, interpolate_pos_encoding)
|
||||
pred_boxes_as_corners = center_to_corners_format(pred_boxes)
|
||||
|
||||
# Loop over query images
|
||||
@ -1478,6 +1574,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
query_pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> OwlViTImageGuidedObjectDetectionOutput:
|
||||
r"""
|
||||
@ -1520,26 +1617,33 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
# Compute feature maps for the input and query images
|
||||
query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0]
|
||||
query_feature_map = self.image_embedder(
|
||||
pixel_values=query_pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)[0]
|
||||
feature_map, vision_outputs = self.image_embedder(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = query_feature_map.shape
|
||||
query_image_feats = torch.reshape(query_feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = query_feature_map.shape
|
||||
query_image_feats = torch.reshape(
|
||||
query_feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim)
|
||||
)
|
||||
# Get top class embedding and best box index for each query image in batch
|
||||
query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)
|
||||
query_embeds, best_box_indices, query_pred_boxes = self.embed_image_query(
|
||||
query_image_feats, query_feature_map, interpolate_pos_encoding
|
||||
)
|
||||
|
||||
# Predict object classes [batch_size, num_patches, num_queries+1]
|
||||
(pred_logits, class_embeds) = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)
|
||||
|
||||
# Predict object boxes
|
||||
target_pred_boxes = self.box_predictor(image_feats, feature_map)
|
||||
target_pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
@ -1574,6 +1678,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> OwlViTObjectDetectionOutput:
|
||||
r"""
|
||||
@ -1625,14 +1730,15 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
# Text and vision model outputs
|
||||
text_outputs = outputs.text_model_output
|
||||
vision_outputs = outputs.vision_model_output
|
||||
|
||||
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
||||
batch_size, num_patches_height, num_patches_width, hidden_dim = feature_map.shape
|
||||
image_feats = torch.reshape(feature_map, (batch_size, num_patches_height * num_patches_width, hidden_dim))
|
||||
|
||||
# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
|
||||
max_text_queries = input_ids.shape[0] // batch_size
|
||||
@ -1646,7 +1752,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||||
(pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
|
||||
|
||||
# Predict object boxes
|
||||
pred_boxes = self.box_predictor(image_feats, feature_map)
|
||||
pred_boxes = self.box_predictor(image_feats, feature_map, interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
|
@ -828,6 +828,144 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([[-6.2229, -8.2601]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "google/owlv2-base-patch16"
|
||||
model = Owlv2Model.from_pretrained(model_name).to(torch_device)
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
processor.image_processor.size = {"height": 1024, "width": 1024}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
self.assertEqual(
|
||||
outputs.logits_per_image.shape,
|
||||
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs.logits_per_text.shape,
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[-6.2520, -8.2970]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
expected_shape = torch.Size((1, 4097, 768))
|
||||
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||
|
||||
# Owlv2ForObjectDetection part.
|
||||
model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
processor.image_processor.size = {"height": 1024, "width": 1024}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.2407, 0.0553, 0.4636], [0.1082, 0.0494, 0.1861], [0.2459, 0.0527, 0.4398]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
# Deactivate interpolate_pos_encoding on same model, and use default image size.
|
||||
# Verify that, once interpolate_pos_encoding is deactivated, the same model falls back to its default patch grid and its default box_bias (Owlv2ForObjectDetection).
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=False)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_default_box_bias = torch.tensor(
|
||||
[
|
||||
[-4.0717, -4.0717, -4.0717, -4.0717],
|
||||
[-3.3644, -4.0717, -4.0717, -4.0717],
|
||||
[-2.9425, -4.0717, -4.0717, -4.0717],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4))
|
||||
|
||||
# Interpolate with any resolution size.
|
||||
processor.image_processor.size = {"height": 1264, "width": 1024}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.2438, 0.0945, 0.4675], [0.1361, 0.0431, 0.2406], [0.2465, 0.0428, 0.4429]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
@slow
|
||||
def test_inference_object_detection(self):
|
||||
model_name = "google/owlv2-base-patch16"
|
||||
|
@ -821,6 +821,144 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
model = OwlViTModel.from_pretrained(model_name).to(torch_device)
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
processor.image_processor.size = {"height": 800, "width": 800}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
self.assertEqual(
|
||||
outputs.logits_per_image.shape,
|
||||
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs.logits_per_text.shape,
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[3.6278, 0.8861]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
expected_shape = torch.Size((1, 626, 768))
|
||||
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||
|
||||
# OwlViTForObjectDetection part.
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0680, 0.0422, 0.1347], [0.2071, 0.0450, 0.4146], [0.2000, 0.0418, 0.3476]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
# Deactivate interpolate_pos_encoding on same model, and use default image size.
|
||||
# Verify that, once interpolate_pos_encoding is deactivated, the same model falls back to its default patch grid (self.num_patches_height, self.num_patches_width) and its default self.box_bias (OwlViTForObjectDetection).
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=False)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_default_box_bias = torch.tensor(
|
||||
[
|
||||
[-3.1332, -3.1332, -3.1332, -3.1332],
|
||||
[-2.3968, -3.1332, -3.1332, -3.1332],
|
||||
[-1.9452, -3.1332, -3.1332, -3.1332],
|
||||
]
|
||||
)
|
||||
self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4))
|
||||
|
||||
# Interpolate with any resolution size.
|
||||
processor.image_processor.size = {"height": 1264, "width": 1024}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0499, 0.0301, 0.0983], [0.2244, 0.0365, 0.4663], [0.1387, 0.0314, 0.1859]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
@slow
|
||||
def test_inference_object_detection(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
|
Loading…
Reference in New Issue
Block a user