mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
[OWL-ViT] Make model consistent with CLIP (#20144)
* Apply fix
* Fix test
* Remove another argument which is not used
* Fix pipeline test
* Add argument back, add deprecation warning
* Add warning at other location
* Use warnings instead
* Add num_channels to config

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
This commit is contained in:
parent d3c0566679
commit cbbeca3d17
@@ -165,6 +165,8 @@ class OwlViTVisionConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         num_attention_heads (`int`, *optional*, defaults to 12):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input images.
         image_size (`int`, *optional*, defaults to 768):
             The size (resolution) of each image.
         patch_size (`int`, *optional*, defaults to 32):
@@ -206,6 +208,7 @@ class OwlViTVisionConfig(PretrainedConfig):
         intermediate_size=3072,
         num_hidden_layers=12,
         num_attention_heads=12,
+        num_channels=3,
         image_size=768,
         patch_size=32,
         hidden_act="quick_gelu",
@@ -222,6 +225,7 @@ class OwlViTVisionConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
         self.image_size = image_size
         self.patch_size = patch_size
         self.hidden_act = hidden_act
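The three hunks above expose `num_channels` on `OwlViTVisionConfig`, matching `CLIPVisionConfig`. A minimal sketch of how the new field can be set (assuming the top-level `transformers` import; the values shown are just the defaults):

```python
from transformers import OwlViTVisionConfig

# num_channels is now an explicit config field (default 3), as in CLIPVisionConfig.
config = OwlViTVisionConfig(
    num_channels=3,  # RGB inputs
    image_size=768,
    patch_size=32,
)
print(config.num_channels)  # 3
```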
@@ -15,6 +15,7 @@
 """ PyTorch OWL-ViT model."""


+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

@@ -516,9 +517,6 @@ OWLVIT_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
-        return_base_image_embeds (`bool`, *optional*):
-            Whether or not to return unprojected image embeddings. Set to `True` when `OwlViTModel` is called within
-            `OwlViTForObjectDetection`.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
@@ -785,7 +783,6 @@ class OwlViTVisionTransformer(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        use_hidden_state: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -809,10 +806,7 @@ class OwlViTVisionTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         pooled_output = last_hidden_state[:, 0, :]

-        if use_hidden_state:
-            pooled_output = self.post_layernorm(last_hidden_state)
-        else:
-            pooled_output = self.post_layernorm(pooled_output)
+        pooled_output = self.post_layernorm(pooled_output)

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
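With `use_hidden_state` removed, `OwlViTVisionTransformer` pools exactly like CLIP: the pooler output is the layer-normalized CLS embedding of shape `(batch_size, hidden_size)`, while the full patch sequence stays in `last_hidden_state`. A short sketch of the expected shapes, assuming the standalone `OwlViTVisionModel` wrapper and a randomly initialized default config:

```python
import torch
from transformers import OwlViTVisionConfig, OwlViTVisionModel

config = OwlViTVisionConfig()  # image_size=768, patch_size=32, hidden_size=768
model = OwlViTVisionModel(config)

# Two random RGB images; 768 / 32 = 24 patches per side -> 576 patches + 1 CLS token.
pixel_values = torch.randn(2, config.num_channels, config.image_size, config.image_size)
outputs = model(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # torch.Size([2, 577, 768])
print(outputs.pooler_output.shape)      # torch.Size([2, 768]), layer-normalized CLS embedding
```

This is also what the test change at the bottom of the diff asserts: `pooler_output` is now 2-dimensional instead of carrying the full patch sequence.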
@@ -963,7 +957,6 @@ class OwlViTModel(OwlViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        return_projected: Optional[bool] = True,
     ) -> torch.FloatTensor:
         r"""
         Returns:
@@ -1000,10 +993,8 @@ class OwlViTModel(OwlViTPreTrainedModel):
         pooled_output = vision_outputs[1]  # pooled_output

         # Return projected output
-        if return_projected:
-            image_features = self.visual_projection(pooled_output)
-        else:
-            image_features = pooled_output
+        image_features = self.visual_projection(pooled_output)
+
         return image_features

     @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
@@ -1044,15 +1035,11 @@ class OwlViTModel(OwlViTPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # Whether to return unprojected image features
-        return_base_image_embeds = return_base_image_embeds if return_base_image_embeds is not None else False
-
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            use_hidden_state=False,
         )

         # Get embeddings for all text queries in all batch samples
@@ -1070,12 +1057,12 @@ class OwlViTModel(OwlViTPreTrainedModel):
         image_embeds = self.visual_projection(image_embeds)

         # normalized features
-        image_embeds_norm = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
-        text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
+        image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)

         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
-        logits_per_text = torch.matmul(text_embeds_norm, image_embeds_norm.t()) * logit_scale
+        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_per_image = logits_per_text.t()

         loss = None
@@ -1083,11 +1070,13 @@ class OwlViTModel(OwlViTPreTrainedModel):
             loss = owlvit_loss(logits_per_text)

         if return_base_image_embeds:
+            warnings.warn(
+                "`return_base_image_embeds` is deprecated and will be removed in v4.27 of Transformers, one can"
+                " obtain the base (unprojected) image embeddings from outputs.vision_model_output.",
+                FutureWarning,
+            )
             last_hidden_state = vision_outputs[0]
             image_embeds = self.vision_model.post_layernorm(last_hidden_state)
-        else:
-            image_embeds = image_embeds_norm
-            text_embeds = text_embeds_norm

         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
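Because `return_base_image_embeds` is only deprecated (not yet removed), callers get a `FutureWarning` and should instead recover the unprojected embeddings from `vision_model_output`, mirroring what `OwlViTForObjectDetection` does in the next hunk. A hedged sketch of that replacement pattern (checkpoint name and inputs are illustrative):

```python
from PIL import Image
from transformers import OwlViTModel, OwlViTProcessor

model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

image = Image.new("RGB", (768, 768))  # placeholder image
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt")
outputs = model(**inputs, return_dict=True)

# Base (unprojected) image embeddings, recovered from the vision tower output
# instead of passing the deprecated `return_base_image_embeds=True`.
last_hidden_state = outputs.vision_model_output.last_hidden_state
base_image_embeds = model.vision_model.post_layernorm(last_hidden_state)
```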
@@ -1276,11 +1265,12 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_base_image_embeds=True,
+            return_dict=True,
         )

         # Resize class token
-        image_embeds = outputs[-3]
+        last_hidden_state = outputs.vision_model_output[0]
+        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
         new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
         class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

@@ -1296,7 +1286,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
             image_embeds.shape[-1],
         )
         image_embeds = image_embeds.reshape(new_size)
-        text_embeds = outputs[-4]
+        text_embeds = outputs.text_embeds

         # Last hidden states from text and vision transformers
         text_model_last_hidden_state = outputs[-2][0]
@@ -120,7 +120,7 @@ class OwlViTVisionModelTester:
         # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
         num_patches = (self.image_size // self.patch_size) ** 2
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
-        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()