mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Flax VisionTransformer (#11951)
* adding vit for flax * added test for Flax-vit and some bug-fixes * overrided methods where variable changes were necessary for flax_vit test * added FlaxViTForImageClassification for test * Update src/transformers/models/vit/modeling_flax_vit.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * made changes suggested in PR * Adding jax-vit models for autoimport * swapping num_channels and height,width dimension * fixing the docstring for torch-like inputs for VIT * add model to main init * add docs * doc, fix-copies * docstrings * small test fixes * fix docs * fix docstr * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * style Co-authored-by: jayendra <jayendra@infocusp.in> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
0eaeae2e36
commit
9a9314f6d9
@ -395,7 +395,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
@ -101,3 +101,18 @@ ViTForImageClassification
|
||||
|
||||
.. autoclass:: transformers.ViTForImageClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
FlaxVitModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxViTModel
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxViTForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxViTForImageClassification
|
||||
:members: __call__
|
||||
|
||||
|
@ -1553,6 +1553,7 @@ if is_flax_available():
|
||||
"FlaxRobertaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
|
||||
else:
|
||||
from .utils import dummy_flax_objects
|
||||
|
||||
@ -2839,6 +2840,7 @@ if TYPE_CHECKING:
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaPreTrainedModel,
|
||||
)
|
||||
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
|
||||
else:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
|
@ -97,6 +97,7 @@ if is_flax_available():
|
||||
"FLAX_MODEL_MAPPING",
|
||||
"FlaxAutoModel",
|
||||
"FlaxAutoModelForCausalLM",
|
||||
"FlaxAutoModelForImageClassification",
|
||||
"FlaxAutoModelForMaskedLM",
|
||||
"FlaxAutoModelForMultipleChoice",
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
@ -182,6 +183,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_MAPPING,
|
||||
FlaxAutoModel,
|
||||
FlaxAutoModelForCausalLM,
|
||||
FlaxAutoModelForImageClassification,
|
||||
FlaxAutoModelForMaskedLM,
|
||||
FlaxAutoModelForMultipleChoice,
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
|
@ -47,8 +47,9 @@ from ..roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||
from .auto_factory import auto_class_factory
|
||||
from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig
|
||||
from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -62,6 +63,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
||||
(GPT2Config, FlaxGPT2Model),
|
||||
(ElectraConfig, FlaxElectraModel),
|
||||
(CLIPConfig, FlaxCLIPModel),
|
||||
(ViTConfig, FlaxViTModel),
|
||||
]
|
||||
)
|
||||
|
||||
@ -83,6 +85,13 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Image-classsification
|
||||
(ViTConfig, FlaxViTForImageClassification),
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
@ -134,6 +143,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
|
||||
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
||||
|
||||
FlaxAutoModelForImageClassification = auto_class_factory(
|
||||
"FlaxAutoModelForImageClassification",
|
||||
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING,
|
||||
head_doc="image classification modeling",
|
||||
)
|
||||
|
||||
FlaxAutoModelForCausalLM = auto_class_factory(
|
||||
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
|
||||
)
|
||||
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available
|
||||
from ...file_utils import _BaseLazyModule, is_flax_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -36,6 +36,9 @@ if is_torch_available():
|
||||
]
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
@ -50,6 +53,9 @@ if TYPE_CHECKING:
|
||||
ViTPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
|
606
src/transformers/models/vit/modeling_flax_vit.py
Normal file
606
src/transformers/models/vit/modeling_flax_vit.py
Normal file
@ -0,0 +1,606 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
|
||||
from ...modeling_flax_utils import (
|
||||
ACT2FN,
|
||||
FlaxPreTrainedModel,
|
||||
append_replace_return_docstrings,
|
||||
overwrite_call_docstring,
|
||||
)
|
||||
from .configuration_vit import ViTConfig
|
||||
|
||||
|
||||
VIT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||
generic methods the library implements for all its model (such as downloading, saving and converting weights from
|
||||
PyTorch models)
|
||||
|
||||
This model is also a Flax Linen `flax.linen.Module
|
||||
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
|
||||
and refer to the Flax documentation for all matter related to general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
|
||||
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
|
||||
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
|
||||
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
|
||||
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
"""
|
||||
|
||||
VIT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using :class:`~transformers.ViTFeatureExtractor`. See
|
||||
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxPatchEmbeddings(nn.Module):
|
||||
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
image_size = self.config.image_size
|
||||
patch_size = self.config.patch_size
|
||||
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
||||
self.num_patches = num_patches
|
||||
self.projection = nn.Conv(
|
||||
self.config.hidden_size,
|
||||
kernel_size=(patch_size, patch_size),
|
||||
strides=(patch_size, patch_size),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, pixel_values):
|
||||
x = self.projection(pixel_values)
|
||||
batch_size, _, _, channels = x.shape
|
||||
return jnp.reshape(x, (batch_size, -1, channels))
|
||||
|
||||
|
||||
class FlaxViTEmbeddings(nn.Module):
|
||||
"""Construct the CLS token, position and patch embeddings."""
|
||||
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
|
||||
self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = self.param(
|
||||
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, pixel_values, deterministic=True):
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
|
||||
cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
|
||||
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
embeddings = self.dropout(embeddings, deterministic=deterministic)
|
||||
return embeddings
|
||||
|
||||
|
||||
class FlaxViTSelfAttention(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
|
||||
)
|
||||
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
|
||||
dropout_rng = None
|
||||
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
|
||||
dropout_rng = self.make_rng("dropout")
|
||||
|
||||
attn_weights = dot_product_attention_weights(
|
||||
query_states,
|
||||
key_states,
|
||||
dropout_rng=dropout_rng,
|
||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||
broadcast_dropout=True,
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype,
|
||||
precision=None,
|
||||
)
|
||||
|
||||
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
||||
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxViTSelfOutput(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxViTAttention(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
|
||||
self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
|
||||
attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
|
||||
attn_output = attn_outputs[0]
|
||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_outputs[1],)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxViTIntermediate(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxViTOutput(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = hidden_states + attention_output
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxViTLayer(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
|
||||
self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxViTOutput(self.config, dtype=self.dtype)
|
||||
self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
|
||||
attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# first residual connection
|
||||
attention_output = attention_output + hidden_states
|
||||
|
||||
# in ViT, layernorm is also applied after self-attention
|
||||
layer_output = self.layernorm_after(attention_output)
|
||||
|
||||
hidden_states = self.intermediate(layer_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxViTLayerCollection(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
class FlaxViTEncoder(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
return self.layer(
|
||||
hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class FlaxViTPooler(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
cls_hidden_state = hidden_states[:, 0]
|
||||
cls_hidden_state = self.dense(cls_hidden_state)
|
||||
return nn.tanh(cls_hidden_state)
|
||||
|
||||
|
||||
class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
if input_shape is None:
|
||||
input_shape = (1, config.image_size, config.image_size, 3)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
# init input tensors
|
||||
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
pixel_values,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxViTModule(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
|
||||
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
pixel_values,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
|
||||
hidden_states = self.embeddings(pixel_values, deterministic=deterministic)
|
||||
|
||||
outputs = self.encoder(
|
||||
hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
||||
|
||||
if not return_dict:
|
||||
# if pooled is None, don't return it
|
||||
if pooled is None:
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxViTModel(FlaxViTPreTrainedModel):
|
||||
module_class = FlaxViTModule
|
||||
|
||||
|
||||
FLAX_VISION_MODEL_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, FlaxViTModel
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||
>>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
|
||||
append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)
|
||||
|
||||
|
||||
class FlaxViTForImageClassificationModule(nn.Module):
|
||||
config: ViTConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.classifier = nn.Dense(
|
||||
self.config.num_labels,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
pixel_values=None,
|
||||
deterministic: bool = True,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vit(
|
||||
pixel_values,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.classifier(hidden_states[:, 0, :])
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return output
|
||||
|
||||
return FlaxSequenceClassifierOutput(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||
the [CLS] token) e.g. for ImageNet.
|
||||
""",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
|
||||
module_class = FlaxViTForImageClassificationModule
|
||||
|
||||
|
||||
FLAX_VISION_CLASSIF_DOCSTRING = """
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification
|
||||
>>> from PIL import Image
|
||||
>>> import jax
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
||||
>>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
|
||||
append_replace_return_docstrings(
|
||||
FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
|
||||
)
|
@ -405,3 +405,17 @@ class FlaxRobertaPreTrainedModel:
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxViTForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxViTModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
240
tests/test_modeling_flax_vit.py
Normal file
240
tests/test_modeling_flax_vit.py
Normal file
@ -0,0 +1,240 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ViTConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
|
||||
import jax
|
||||
from transformers.models.vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||
|
||||
|
||||
class FlaxViTModelTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
config = ViTConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
model = FlaxViTModel(config=config)
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxViTModel, FlaxViTForImageClassification) if is_flax_available() else ()
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_tester = FlaxViTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# We need to override this test because in ViT, the seq_len equals the number of patches + 1
|
||||
# we compute that here
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
seq_length = num_patches + 1
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
||||
)
|
||||
|
||||
# We neeed to override this test because ViT's forward signature is different than text models.
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.__call__)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
# We neeed to override this test because ViT expects pixel_values instead of input_ids
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(pixel_values, **kwargs):
|
||||
return model(pixel_values=pixel_values, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
# We need to override this test because in ViT, the seq_len equals the number of patches + 1
|
||||
# we compute that here
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
seq_length = num_patches + 1 # we add 1 for the [CLS] token
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("google/vit-base-patch16-224")
|
||||
outputs = model(np.ones((1, 3, 224, 224)))
|
||||
self.assertIsNotNone(outputs)
|
Loading…
Reference in New Issue
Block a user