diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9e2d093eb8a..0f11962cbaf 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -395,7 +395,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Transformer-XL              | ✅             | ❌             | ✅              | ✅                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| ViT                         | ❌             | ❌             | ✅              | ❌                 | ❌           |
+| ViT                         | ❌             | ❌             | ✅              | ❌                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | VisualBert                  | ❌             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
diff --git a/docs/source/model_doc/vit.rst b/docs/source/model_doc/vit.rst
index a010a711995..3c396c54eb6 100644
--- a/docs/source/model_doc/vit.rst
+++ b/docs/source/model_doc/vit.rst
@@ -101,3 +101,18 @@ ViTForImageClassification
 
 .. autoclass:: transformers.ViTForImageClassification
     :members: forward
+
+
+FlaxViTModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxViTModel
+    :members: __call__
+
+
+FlaxViTForImageClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxViTForImageClassification
+    :members: __call__
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 387e8b938ad..d4dba2e0616 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1553,6 +1553,7 @@ if is_flax_available():
             "FlaxRobertaPreTrainedModel",
         ]
     )
+    _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
 else:
     from .utils import dummy_flax_objects
 
@@ -2839,6 +2840,7 @@ if TYPE_CHECKING:
         FlaxRobertaModel,
         FlaxRobertaPreTrainedModel,
     )
+    from .models.vit import FlaxViTForImageClassification, FlaxViTModel
 else:
     # Import the same objects as dummies to get them in the namespace.
     # They will raise an import error if the user tries to instantiate / use them.
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index a620c0a75dd..21238894787 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -97,6 +97,7 @@ if is_flax_available():
             "FLAX_MODEL_MAPPING",
             "FlaxAutoModel",
             "FlaxAutoModelForCausalLM",
+            "FlaxAutoModelForImageClassification",
             "FlaxAutoModelForMaskedLM",
             "FlaxAutoModelForMultipleChoice",
             "FlaxAutoModelForNextSentencePrediction",
@@ -182,6 +183,7 @@ if TYPE_CHECKING:
         FLAX_MODEL_MAPPING,
         FlaxAutoModel,
         FlaxAutoModelForCausalLM,
+        FlaxAutoModelForImageClassification,
         FlaxAutoModelForMaskedLM,
         FlaxAutoModelForMultipleChoice,
         FlaxAutoModelForNextSentencePrediction,
diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py
index 3026db6d6bc..56af5b81f72 100644
--- a/src/transformers/models/auto/modeling_flax_auto.py
+++ b/src/transformers/models/auto/modeling_flax_auto.py
@@ -47,8 +47,9 @@ from ..roberta.modeling_flax_roberta import (
     FlaxRobertaForTokenClassification,
     FlaxRobertaModel,
 )
+from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
 from .auto_factory import auto_class_factory
-from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig
+from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
 
 logger = logging.get_logger(__name__)
 
@@ -62,6 +63,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         (GPT2Config, FlaxGPT2Model),
         (ElectraConfig, FlaxElectraModel),
         (CLIPConfig, FlaxCLIPModel),
+        (ViTConfig, FlaxViTModel),
     ]
 )
 
@@ -83,6 +85,13 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     ]
 )
 
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
+    [
+        # Model for Image classification
+        (ViTConfig, FlaxViTForImageClassification),
+    ]
+)
+
 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
     [
         # Model for Causal LM mapping
@@ -134,6 +143,12 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
 
 FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
 
+FlaxAutoModelForImageClassification = auto_class_factory(
+    "FlaxAutoModelForImageClassification",
+    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    head_doc="image classification modeling",
+)
+
 FlaxAutoModelForCausalLM = auto_class_factory(
     "FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
 )
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index a8164e2bfe5..eb9c8f43081 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available
+from ...file_utils import _BaseLazyModule, is_flax_available, is_torch_available, is_vision_available
 
 
 _import_structure = {
@@ -36,6 +36,9 @@ if is_torch_available():
     ]
 
 
+if is_flax_available():
+    _import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
+
 if TYPE_CHECKING:
     from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
 
@@ -50,6 +53,9 @@ if TYPE_CHECKING:
         ViTPreTrainedModel,
     )
 
+    if is_flax_available():
+        from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
+
 else:
     import importlib
diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py
new file mode 100644
index 00000000000..7ce86664e37
--- /dev/null
+++ b/src/transformers/models/vit/modeling_flax_vit.py
@@ -0,0 +1,606 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from flax.linen.attention import dot_product_attention_weights
+
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
+from ...modeling_flax_utils import (
+    ACT2FN,
+    FlaxPreTrainedModel,
+    append_replace_return_docstrings,
+    overwrite_call_docstring,
+)
+from .configuration_vit import ViTConfig
+
+
+VIT_START_DOCSTRING = r"""
+
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its models (such as downloading, saving and converting weights
+    from PyTorch models).
+
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matters related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
+
+    Parameters:
+        config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
+"""
+
+VIT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using :class:`~transformers.ViTFeatureExtractor`. See
+            :meth:`transformers.ViTFeatureExtractor.__call__` for details.
+
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+class FlaxPatchEmbeddings(nn.Module):
+
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        image_size = self.config.image_size
+        patch_size = self.config.patch_size
+        num_patches = (image_size // patch_size) * (image_size // patch_size)
+        self.num_patches = num_patches
+        self.projection = nn.Conv(
+            self.config.hidden_size,
+            kernel_size=(patch_size, patch_size),
+            strides=(patch_size, patch_size),
+            padding="VALID",
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+        )
+
+    def __call__(self, pixel_values):
+        x = self.projection(pixel_values)
+        batch_size, _, _, channels = x.shape
+        # flatten the patch grid: (batch, height', width', hidden) -> (batch, num_patches, hidden)
+        return jnp.reshape(x, (batch_size, -1, channels))
+
+
+class FlaxViTEmbeddings(nn.Module):
+    """Construct the CLS token, position and patch embeddings."""
+
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
+        self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = self.param(
+            "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, pixel_values, deterministic=True):
+        batch_size = pixel_values.shape[0]
+
+        embeddings = self.patch_embeddings(pixel_values)
+
+        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
+        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
+        embeddings = embeddings + self.position_embeddings
+        embeddings = self.dropout(embeddings, deterministic=deterministic)
+        return embeddings
+
+
+class FlaxViTSelfAttention(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        if self.config.hidden_size % self.config.num_attention_heads != 0:
+            raise ValueError(
+                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+            )
+
+        self.query = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+        )
+        self.key = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+        )
+        self.value = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+        )
+
+    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+        query_states = self.query(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        value_states = self.value(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        key_states = self.key(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+
+        dropout_rng = None
+        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        attn_weights = dot_product_attention_weights(
+            query_states,
+            key_states,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.config.attention_probs_dropout_prob,
+            broadcast_dropout=True,
+            deterministic=deterministic,
+            dtype=self.dtype,
+            precision=None,
+        )
+
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+        return outputs
+
+
+class FlaxViTSelfOutput(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            dtype=self.dtype,
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        return hidden_states
+
+
+class FlaxViTAttention(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
+        self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
+        attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
+        attn_output = attn_outputs[0]
+        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_outputs[1],)
+
+        return outputs
+
+
+class FlaxViTIntermediate(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.intermediate_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            dtype=self.dtype,
+        )
+        self.activation = ACT2FN[self.config.hidden_act]
+
+    def __call__(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+class FlaxViTOutput(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            dtype=self.dtype,
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        hidden_states = hidden_states + attention_output
+        return hidden_states
+
+
+class FlaxViTLayer(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
+        self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
+        self.output = FlaxViTOutput(self.config, dtype=self.dtype)
+        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
+        attention_outputs = self.attention(
+            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = attention_outputs[0]
+
+        # first residual connection
+        attention_output = attention_output + hidden_states
+
+        # in ViT, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(attention_output)
+
+        hidden_states = self.intermediate(layer_output)
+        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attention_outputs[1],)
+        return outputs
+
+
+class FlaxViTLayerCollection(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layers = [
+            FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(
+        self,
+        hidden_states,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions += (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+        if not return_dict:
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
+
+
+class FlaxViTEncoder(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)
+
+    def __call__(
+        self,
+        hidden_states,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        return self.layer(
+            hidden_states,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class FlaxViTPooler(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states):
+        cls_hidden_state = hidden_states[:, 0]
+        cls_hidden_state = self.dense(cls_hidden_state)
+        return nn.tanh(cls_hidden_state)
+
+
+class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+ """ + + config_class = ViTConfig + base_model_prefix = "vit" + module_class: nn.Module = None + + def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, 3) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, pixel_values, return_dict=False)["params"] + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxViTModule(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class FlaxViTModel(FlaxViTPreTrainedModel): + module_class = FlaxViTModule + + 
+FLAX_VISION_MODEL_DOCSTRING = """
+    Returns:
+
+    Examples::
+
+        >>> from transformers import ViTFeatureExtractor, FlaxViTModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+        >>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+
+        >>> inputs = feature_extractor(images=image, return_tensors="jax")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state
+"""
+
+overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
+append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)
+
+
+class FlaxViTForImageClassificationModule(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
+        self.classifier = nn.Dense(
+            self.config.num_labels,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
+        )
+
+    def __call__(
+        self,
+        pixel_values=None,
+        deterministic: bool = True,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vit(
+            pixel_values,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        # classify from the final hidden state of the [CLS] token
+        logits = self.classifier(hidden_states[:, 0, :])
+
+        if not return_dict:
+            # the base module is built without a pooler, so the extra outputs start at index 1
+            output = (logits,) + outputs[1:]
+            return output
+
+        return FlaxSequenceClassifierOutput(
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+    the [CLS] token) e.g. for ImageNet.
+ """, + VIT_START_DOCSTRING, +) +class FlaxViTForImageClassification(FlaxViTPreTrainedModel): + module_class = FlaxViTForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') + >>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="jax") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) +""" + +overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index fddd0d36705..05c6b41f969 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -405,3 +405,17 @@ class FlaxRobertaPreTrainedModel: @classmethod def from_pretrained(self, *args, **kwargs): requires_backends(self, ["flax"]) + + +class FlaxViTForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxViTModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_vit.py b/tests/test_modeling_flax_vit.py new file mode 100644 index 00000000000..276777e0009 --- /dev/null +++ b/tests/test_modeling_flax_vit.py @@ -0,0 +1,240 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+import unittest
+
+import numpy as np
+
+from transformers import ViTConfig, is_flax_available
+from transformers.testing_utils import require_flax, slow
+
+from .test_configuration_common import ConfigTester
+from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
+
+
+if is_flax_available():
+
+    import jax
+    from transformers.models.vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
+
+
+class FlaxViTModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        image_size=30,
+        patch_size=2,
+        num_channels=3,
+        is_training=True,
+        use_labels=True,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        type_sequence_label_size=10,
+        initializer_range=0.02,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.is_training = is_training
+        self.use_labels = use_labels
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+        config = ViTConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, pixel_values
+
+    def create_and_check_model(self, config, pixel_values):
+
+        model = FlaxViTModel(config=config)
+        result = model(pixel_values)
+        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+        image_size = (self.image_size, self.image_size)
+        patch_size = (self.patch_size, self.patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            pixel_values,
+        ) = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_flax
+class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (FlaxViTModel, FlaxViTForImageClassification) if is_flax_available() else ()
+
+    def setUp(self) -> None:
+        self.model_tester = FlaxViTModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    # We need to override this test because in ViT, the seq_len equals the number of patches + 1
+    # we compute that here
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        num_patches = (config.image_size // config.patch_size) ** 2
+        seq_length = num_patches + 1
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also works using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length, seq_length],
+            )
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            added_hidden_states = 1
+            self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, seq_length, seq_length],
+            )
+
+    # We need to override this test because ViT's forward signature is different from that of text models.
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.__call__)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            expected_arg_names = ["pixel_values"]
+            self.assertListEqual(arg_names[:1], expected_arg_names)
+
+    # We need to override this test because ViT expects pixel_values instead of input_ids
+    def test_jit_compilation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+
+                @jax.jit
+                def model_jitted(pixel_values, **kwargs):
+                    return model(pixel_values=pixel_values, **kwargs)
+
+                with self.subTest("JIT Enabled"):
+                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                with self.subTest("JIT Disabled"):
+                    with jax.disable_jit():
+                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                self.assertEqual(len(outputs), len(jitted_outputs))
+                for jitted_output, output in zip(jitted_outputs, outputs):
+                    self.assertEqual(jitted_output.shape, output.shape)
+
+    # We need to override this test because in ViT, the seq_len equals the number of patches + 1
+    # we compute that here
+    def test_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            num_patches = (config.image_size // config.patch_size) ** 2
+            seq_length = num_patches + 1  # we add 1 for the [CLS] token
+
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            hidden_states = outputs.hidden_states
+
+            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
+
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [seq_length, self.model_tester.hidden_size],
+            )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_hidden_states"] = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+            # check that output_hidden_states also works using config
+            del inputs_dict["output_hidden_states"]
+            config.output_hidden_states = True
+
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_class_name in self.all_model_classes:
+            model = model_class_name.from_pretrained("google/vit-base-patch16-224")
+            outputs = model(np.ones((1, 3, 224, 224)))
+            self.assertIsNotNone(outputs)
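+
+# Running this module requires a JAX/Flax install (e.g. `pip install transformers[flax]`):
+#
+#     python -m pytest tests/test_modeling_flax_vit.py -v
+#
+# The @slow test above is skipped unless RUN_SLOW=1 is set in the environment; it downloads
+# the google/vit-base-patch16-224 checkpoint.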