mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

* expand the test for VLMs * typo * mark models `supports_flex` + expand test for additional kwargs * flex attn for refactored vision models * fix copies * fix * unskip * style * address comments
659 lines
28 KiB
Python
659 lines
28 KiB
Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_ijepa.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
import collections.abc
|
|
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from ...activations import ACT2FN
|
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
|
from ...utils import auto_docstring, logging, torch_int
|
|
from .configuration_ijepa import IJepaConfig
|
|
|
|
|
|
# Module-level logger, created through the project's `logging` helper imported above.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class IJepaPatchEmbeddings(nn.Module):
    """
    Convert `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states`
    (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` that a Transformer consumes.

    Patches are produced by a single `nn.Conv2d` whose kernel size and stride both equal the patch size.
    """

    def __init__(self, config):
        super().__init__()

        def as_pair(value):
            # A scalar describes a square; anything iterable is stored exactly as given.
            if isinstance(value, collections.abc.Iterable):
                return value
            return (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)

        # Number of whole patches that fit along each spatial axis.
        patches_along_height = image_size[0] // patch_size[0]
        patches_along_width = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = patches_along_width * patches_along_height

        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Project an image batch to a sequence of patch embeddings.

        Args:
            pixel_values: 4-D tensor unpacked as `(batch, channels, height, width)`.
            interpolate_pos_encoding: When `True`, the check that the spatial size equals the configured
                `image_size` is skipped.

        Returns:
            The convolution output with its spatial axes flattened and moved before the channel axis.

        Raises:
            ValueError: If the channel count differs from the configured one, or if the spatial size differs from
                the configured one while `interpolate_pos_encoding` is `False`.
        """
        _, num_channels, height, width = pixel_values.shape

        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )

        size_must_match = not interpolate_pos_encoding
        if size_must_match and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )

        feature_map = self.projection(pixel_values)
        # (batch, hidden, h', w') -> (batch, hidden, h' * w') -> (batch, h' * w', hidden)
        return feature_map.flatten(2).transpose(1, 2)
|
|
|
|
|
|
class IJepaEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    No CLS token is created by this module: the learned position table has exactly one entry per patch and
    `forward` returns one embedding per patch.
    """

    def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
        super().__init__()
        # Learned vector substituted for masked patches; only allocated when requested.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = IJepaPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Stored exactly as configured (scalar or pair). Internal arithmetic uses the normalised pair held by
        # `self.patch_embeddings.patch_size` instead.
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

        Args:
            embeddings: Patch embeddings; only the sizes of dimension 1 (sequence) and the last dimension are read.
            height: Height of the input image in pixels.
            width: Width of the input image in pixels.

        Returns:
            The stored position embeddings unchanged when no resizing is needed, otherwise a bicubically resized
            copy flattened to `(1, new_height * new_width, dim)`.
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        patch_pos_embed = self.position_embeddings

        dim = embeddings.shape[-1]

        # `config.patch_size` may be a scalar or a (height, width) pair. Dividing by the raw config value fails
        # for a pair, so use the pair already normalised by the patch-embedding module.
        patch_size = self.patch_embeddings.patch_size
        new_height = height // patch_size[0]
        new_width = width // patch_size[1]

        # NOTE: the reshape below assumes the stored positions form a square grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        # Channels-first layout, as required by `nn.functional.interpolate`.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """
        Embed an image batch: patch projection, optional masking, position encoding, dropout.

        Args:
            pixel_values: 4-D image tensor unpacked as `(batch, channels, height, width)`.
            bool_masked_pos: Optional per-patch mask; where it is set, the patch embedding is replaced by the
                learned mask token. Requires the module to have been built with `use_mask_token=True`.
            interpolate_pos_encoding: Resize the position embeddings to the input resolution instead of adding
                the stored table directly.

        Returns:
            The embeddings after dropout, one vector per patch.

        Raises:
            ValueError: If `bool_masked_pos` is given but no mask token was created.
        """
        batch_size, _, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            if self.mask_token is None:
                # Fail with an actionable message instead of an AttributeError on `None.expand`.
                raise ValueError(
                    "`bool_masked_pos` was provided but this module has no mask token; "
                    "instantiate it with `use_mask_token=True`."
                )
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
|
|
|
|
|
|
@auto_docstring
class IJepaPreTrainedModel(PreTrainedModel):
    # Class-level declarations: config class, weight prefix, main input name and feature / attention-backend flags.
    config_class = IJepaConfig
    base_model_prefix = "ijepa"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights of a single sub-module according to its type."""

        def truncated_normal(tensor: torch.Tensor) -> torch.Tensor:
            # Sample in `fp32` and cast back to the original `dtype`, because
            # `trunc_normal_cpu` is not implemented for `half`.
            sampled = nn.init.trunc_normal_(
                tensor.to(torch.float32), mean=0.0, std=self.config.initializer_range
            )
            return sampled.to(tensor.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data = truncated_normal(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # Identity affine transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, IJepaEmbeddings):
            module.position_embeddings.data = truncated_normal(module.position_embeddings.data)
            if module.mask_token is not None:
                module.mask_token.data.zero_()
|
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-fused) scaled dot-product attention.

    Args:
        module: Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: Attention inputs; the last two dimensions of `query` and `key` are contracted.
        attention_mask: Optional tensor multiplied into the attention probabilities after dropout.
        scaling: Factor applied to the raw query-key scores.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted and ignored.

    Returns:
        A pair `(attn_output, attn_weights)`. `attn_output` has dimensions 1 and 2 swapped and is contiguous.
    """
    # Raw similarity between every query and every key.
    scores = (query @ key.transpose(-2, -1)) * scaling

    # Softmax is evaluated in fp32, then cast back to the dtype of `query`.
    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probabilities = probabilities.to(query.dtype)

    # Dropout here removes whole attended-to tokens, as in the original Transformer paper.
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    # Multiplicative mask (used by callers in this file to mask heads).
    if attention_mask is not None:
        probabilities = probabilities * attention_mask

    context = (probabilities @ value).transpose(1, 2).contiguous()
    return context, probabilities
|
|
|
|
|
|
class IJepaSelfAttention(nn.Module):
    """Multi-head self-attention: Q/K/V projections followed by a pluggable attention kernel."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()

        heads_divide_hidden = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_hidden and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        # Three independent projections sharing the same input / output widths and bias setting.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into `(num_heads, head_size)` and swap dimensions 1 and 2."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Run self-attention over `hidden_states`.

        Returns a 1-tuple `(context,)`, or `(context, attention_probs)` when `output_attentions` is set.
        """
        split_heads = self.transpose_for_scores
        key_layer = split_heads(self.key(hidden_states))
        value_layer = split_heads(self.value(hidden_states))
        query_layer = split_heads(self.query(hidden_states))

        # Select the attention kernel; SDPA cannot return weights, so that combination uses the eager path.
        implementation = self.config._attn_implementation
        attention_interface: Callable
        if implementation == "eager":
            attention_interface = eager_attention_forward
        elif implementation == "sdpa" and output_attentions:
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=self.dropout_prob if self.training else 0.0,
        )

        # Merge the trailing (heads, head_size) dimensions back into one feature dimension.
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(merged_shape)

        if output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
|
|
|
|
|
|
class IJepaSelfOutput(nn.Module):
    """
    Output projection of the attention block.

    The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection then dropout; `input_tensor` is accepted but not used here."""
        return self.dropout(self.dense(hidden_states))
|
|
|
|
|
|
class IJepaAttention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning heads."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        self.attention = IJepaSelfAttention(config)
        self.output = IJepaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given heads from the Q/K/V and output projections and record them as pruned."""
        if not heads:
            return

        self_attention = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self_attention.num_attention_heads,
            self_attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        for projection_name in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(self_attention, projection_name), index)
            setattr(self_attention, projection_name, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self_attention.num_attention_heads = self_attention.num_attention_heads - len(heads)
        self_attention.all_head_size = self_attention.attention_head_size * self_attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return `(attention_output,)`, followed by whatever extra items the self-attention module returned."""
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(self_outputs[0], hidden_states)
        # Items after index 0 are the attention weights, present only when they were requested.
        return (projected,) + self_outputs[1:]
|
|
|
|
|
|
class IJepaIntermediate(nn.Module):
    """First half of the feed-forward block: a linear layer to `intermediate_size`, then the activation."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in `ACT2FN`; anything else is used directly as the activation.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer followed by the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
|
class IJepaOutput(nn.Module):
    """Second half of the feed-forward block: projection back to `hidden_size`, dropout, and the residual add."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, apply dropout, and add `input_tensor` as the residual."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
|
|
|
|
|
|
class IJepaLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = IJepaAttention(config)
        self.intermediate = IJepaIntermediate(config)
        self.output = IJepaOutput(config)
        norm_kwargs = {"eps": config.layer_norm_eps}
        self.layernorm_before = nn.LayerNorm(config.hidden_size, **norm_kwargs)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Apply pre-norm self-attention and a pre-norm feed-forward block, each wrapped in a residual connection.

        Returns `(layer_output,)`, followed by the attention weights when the attention module returned them.
        """
        # in IJepa, layernorm is applied before self-attention
        normed_input = self.layernorm_before(hidden_states)
        attention_results = self.attention(normed_input, head_mask, output_attentions=output_attentions)
        attention_output, extras = attention_results[0], attention_results[1:]

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in IJepa, layernorm is also applied after self-attention
        feed_forward_hidden = self.intermediate(self.layernorm_after(hidden_states))

        # second residual connection is done inside `self.output`
        layer_output = self.output(feed_forward_hidden, hidden_states)

        return (layer_output,) + extras
|
|
|
|
|
|
class IJepaEncoder(nn.Module):
    """Sequence of `config.num_hidden_layers` `IJepaLayer` blocks applied one after another."""

    def __init__(self, config: IJepaConfig) -> None:
        super().__init__()
        self.config = config
        blocks = [IJepaLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run every layer in order, optionally collecting per-layer hidden states and attention weights.

        When `output_hidden_states` is set, the collected states include the input to the first layer and the
        output of the last one. With `return_dict=False` a plain tuple of the non-`None` results is returned.
        """
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for index, block in enumerate(self.layer):
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            block_head_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                # `_gradient_checkpointing_func` is not defined in this class; it is expected to be attached
                # from outside before this flag is enabled.
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    block_head_mask,
                    output_attentions,
                )
            else:
                block_outputs = block(hidden_states, block_head_mask, output_attentions)

            hidden_states = block_outputs[0]

            if output_attentions:
                attention_trace.append(block_outputs[1])

        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = None if attention_trace is None else tuple(attention_trace)

        if not return_dict:
            candidates = [hidden_states, all_hidden_states, all_self_attentions]
            return tuple(item for item in candidates if item is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
|
class IJepaPooler(nn.Module):
    """Pool a sequence by projecting and activating the hidden state found at sequence index 0."""

    def __init__(self, config: IJepaConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
        self.activation = ACT2FN[config.pooler_act]

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        # NOTE(review): the embeddings module in this file prepends no extra token, so index 0 looks like the
        # first patch rather than a dedicated summary token - confirm this is intended.
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
|
|
|
|
|
|
@auto_docstring
class IJepaModel(IJepaPreTrainedModel):
    def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
        r"""
        add_pooling_layer (`bool`, *optional*, defaults to `False`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        """
        super().__init__(config)
        self.config = config
        self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = IJepaEncoder(config)

        # Final normalisation applied to the encoder output before (optional) pooling.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = IJepaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> IJepaPatchEmbeddings:
        """Return the patch-embedding module that consumes `pixel_values`."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Per-call arguments win; otherwise fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both the tuple and the output-object form.
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            # The pooled output is only part of the tuple when a pooler exists.
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
    e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    """
)
class IJepaForImageClassification(IJepaPreTrainedModel):
    def __init__(self, config: IJepaConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.ijepa = IJepaModel(config, add_pooling_layer=False)

        # Classifier head; with no labels configured the pooled features pass through unchanged.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.ijepa(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Average over the sequence dimension, then classify.
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output.mean(dim=1))

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config for later calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"]
|