* beit-flax

* updated FLAX_BEIT_MLM_DOCSTRING

* removed bool_masked_pos from classification

* updated Copyright

* code refactoring: x -> embeddings

* updated test: rm from_pt

* Update docs/source/model_doc/beit.rst

* model code dtype updates and
other changes according to review

* relative_position_bias
revert back to pytorch design
This commit is contained in:
Kamal Raj 2021-09-21 17:04:19 +05:30 committed by GitHub
parent 48fa42e5d5
commit a2dec768a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1346 additions and 3 deletions

View File

@ -342,7 +342,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BeiT | ❌ | ❌ | ✅ | ❌ | |
| BeiT | ❌ | ❌ | ✅ | ❌ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

View File

@ -59,7 +59,8 @@ Tips:
:obj:`use_relative_position_bias` attribute of :class:`~transformers.BeitConfig` to :obj:`True` in order to add
position embeddings.
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code can be found `here
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The JAX/FLAX version of this model was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
<https://github.com/microsoft/unilm/tree/master/beit>`__.
BeitConfig
@ -95,3 +96,24 @@ BeitForImageClassification
.. autoclass:: transformers.BeitForImageClassification
:members: forward
FlaxBeitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBeitModel
:members: __call__
FlaxBeitForMaskedImageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBeitForMaskedImageModeling
:members: __call__
FlaxBeitForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBeitForImageClassification
:members: __call__

View File

@ -1751,6 +1751,14 @@ if is_flax_available():
"FlaxBartPreTrainedModel",
]
)
_import_structure["models.beit"].extend(
[
"FlaxBeitForImageClassification",
"FlaxBeitForMaskedImageModeling",
"FlaxBeitModel",
"FlaxBeitPreTrainedModel",
]
)
_import_structure["models.bert"].extend(
[
"FlaxBertForMaskedLM",
@ -3324,6 +3332,12 @@ if TYPE_CHECKING:
FlaxBartModel,
FlaxBartPreTrainedModel,
)
from .models.beit import (
FlaxBeitForImageClassification,
FlaxBeitForMaskedImageModeling,
FlaxBeitModel,
FlaxBeitPreTrainedModel,
)
from .models.bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,

View File

@ -33,6 +33,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"),
("beit", "FlaxBeitModel"),
("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"),
("gpt2", "FlaxGPT2Model"),
@ -95,6 +96,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classsification
("vit", "FlaxViTForImageClassification"),
("beit", "FlaxBeitForImageClassification"),
]
)

View File

@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
from ...file_utils import _LazyModule, is_flax_available, is_torch_available, is_vision_available
_import_structure = {
@ -37,6 +37,15 @@ if is_torch_available():
"BeitPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_beit"] = [
"FlaxBeitForImageClassification",
"FlaxBeitForMaskedImageModeling",
"FlaxBeitModel",
"FlaxBeitPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
@ -52,6 +61,14 @@ if TYPE_CHECKING:
BeitPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_beit import (
FlaxBeitForImageClassification,
FlaxBeitForMaskedImageModeling,
FlaxBeitModel,
FlaxBeitPreTrainedModel,
)
else:
import sys

View File

@ -0,0 +1,886 @@
# coding=utf-8
# Copyright 2021 Microsoft Research and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
FlaxMaskedLMOutput,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from .configuration_beit import BeitConfig
BEIT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)
This model is also a Flax Linen `flax.linen.Module
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.BeitConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
BEIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using :class:`~transformers.BeitFeatureExtractor`. See
:meth:`transformers.BeitFeatureExtractor.__call__` for details.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray:
"""
get pair-wise relative position index for each token inside the window
"""
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
coords_h = np.arange(window_size[0])
coords_w = np.arange(window_size[1])
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = np.reshape(coords, (2, -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return jnp.array(relative_position_index)
def ones_with_scale(key, shape, scale, dtype=jnp.float32):
return jnp.ones(shape, dtype) * scale
class FlaxBeitDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
rate: float
@nn.module.compact
def __call__(self, inputs, deterministic: Optional[bool] = True):
if self.rate == 0.0:
return inputs
keep_prob = 1.0 - self.rate
if deterministic:
return inputs
else:
shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
rng = self.make_rng("droppath")
random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
binary_tensor = jnp.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
class FlaxBeitPatchEmbeddings(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
image_size = self.config.image_size
patch_size = self.config.patch_size
num_patches = (image_size // patch_size) * (image_size // patch_size)
patch_shape = (image_size // patch_size, image_size // patch_size)
self.num_patches = num_patches
self.patch_shape = patch_shape
self.projection = nn.Conv(
self.config.hidden_size,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
padding="VALID",
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, pixel_values):
embeddings = self.projection(pixel_values)
batch_size, _, _, channels = embeddings.shape
return jnp.reshape(embeddings, (batch_size, -1, channels))
class FlaxBeitEmbeddings(nn.Module):
"""Construct the CLS token, position and patch embeddings."""
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
if self.config.use_mask_token:
self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype)
num_patches = self.patch_embeddings.num_patches
if self.config.use_absolute_position_embeddings:
self.position_embeddings = self.param(
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True):
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.shape
cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
cls_tokens = cls_tokens.astype(embeddings.dtype)
if bool_masked_pos is not None:
mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size))
mask_tokens = mask_tokens.astype(embeddings.dtype)
# replace the masked visual tokens by mask_tokens
w = jnp.expand_dims(bool_masked_pos, axis=-1)
embeddings = embeddings * (1 - w) + mask_tokens * w
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
if self.config.use_absolute_position_embeddings:
embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype)
embeddings = self.dropout(embeddings, deterministic=deterministic)
return embeddings
class FlaxBeitRelativePositionBias(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3
self.relative_position_bias_table = self.param(
"relative_position_bias_table",
nn.initializers.zeros,
(num_relative_distance, self.config.num_attention_heads),
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
self.relative_position_index = relative_position_index_init(self.window_size)
def __call__(self):
index = self.relative_position_index.reshape(-1)
shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)
relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH
return jnp.transpose(relative_position_bias, (2, 0, 1))
class FlaxBeitSelfAttention(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr(
self.config, "embedding_size"
):
raise ValueError(
f"The hidden size {self.config.hidden_size,} is not a multiple of the number of attention "
f"heads {self.config.num_attention_heads}."
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
use_bias=False,
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.relative_position_bias = (
FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype)
if self.window_size
else None
)
def __call__(
self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attention_bias = jnp.array(0.0, dtype=self.dtype)
# Add relative position bias if present.
if self.relative_position_bias is not None:
attention_bias = jnp.expand_dims(self.relative_position_bias(), 0)
attention_bias = attention_bias.astype(query_states.dtype)
# Add shared relative position bias if provided.
if relative_position_bias is not None:
attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype)
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
class FlaxBeitSelfOutput(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxBeitAttention(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype)
self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype)
def __call__(
self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False
):
attn_outputs = self.attention(
hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
)
attn_output = attn_outputs[0]
attn_output = self.output(attn_output, deterministic=deterministic)
outputs = (attn_output,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
class FlaxBeitIntermediate(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxBeitOutput(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxBeitLayer(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
drop_path_rate: float
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.init_values = self.config.layer_scale_init_value
if self.init_values > 0:
self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values)
self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values)
else:
self.lambda_1 = None
self.lambda_2 = None
def __call__(
self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
):
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
relative_position_bias,
deterministic=deterministic,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
# apply lambda_1 if present
if self.lambda_1 is not None:
attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output
# first residual connection
hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states
# in BEiT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, deterministic=deterministic)
# apply lambda_2 if present
if self.lambda_2 is not None:
layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output
# second residual connection
layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states
outputs = (layer_output,)
if output_attentions:
outputs += (self_attention_outputs[1],)
return outputs
class FlaxBeitLayerCollection(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
drop_path_rates: List[float]
relative_position_bias: Callable[[], jnp.ndarray]
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxBeitLayer(
self.config,
window_size=self.window_size if self.config.use_relative_position_bias else None,
drop_path_rate=self.drop_path_rates[i],
name=str(i),
dtype=self.dtype,
)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None
layer_outputs = layer(
hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxBeitEncoder(nn.Module):
config: BeitConfig
window_size: Tuple[int, int]
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.use_shared_relative_position_bias:
self.relative_position_bias = FlaxBeitRelativePositionBias(
config=self.config, window_size=self.window_size, dtype=self.dtype
)
# stochastic depth decay rule
drop_path_rates = [x for x in np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)]
self.layer = FlaxBeitLayerCollection(
self.config,
window_size=self.window_size,
drop_path_rates=drop_path_rates,
relative_position_bias=self.relative_position_bias
if self.config.use_shared_relative_position_bias
else None,
dtype=self.dtype,
)
def __call__(
self,
hidden_states,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BeitConfig
base_model_prefix = "beit"
module_class: nn.Module = None
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
params_rng, dropout_rng = jax.random.split(rng)
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
pixel_values,
bool_masked_pos=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
rngs["dropout"] = dropout_rng
rngs["droppath"] = droppath_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
bool_masked_pos,
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxBeitPooler(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.use_mean_pooling:
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(self, hidden_states):
if self.config.use_mean_pooling:
# Mean pool the final hidden states of the patch tokens
patch_tokens = hidden_states[:, 1:, :]
pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
else:
# Pool by simply taking the final hidden state of the [CLS] token
pooled_output = hidden_states[:, 0]
return pooled_output
class FlaxBeitModule(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBeitEncoder(
self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype
)
if not self.config.use_mean_pooling:
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
def __call__(
self,
pixel_values,
bool_masked_pos=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if not self.config.use_mean_pooling:
hidden_states = self.layernorm(hidden_states)
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
BEIT_START_DOCSTRING,
)
class FlaxBeitModel(FlaxBeitPreTrainedModel):
module_class = FlaxBeitModule
FLAX_BEIT_MODEL_DOCSTRING = """
Returns:
Examples::
>>> from transformers import BeitFeatureExtractor, FlaxBeitModel
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
>>> model = FlaxBeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
>>> inputs = feature_extractor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig)
class FlaxBeitForMaskedImageModelingModule(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)
# Classifier head
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(
self,
pixel_values=None,
bool_masked_pos=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.beit(
pixel_values,
bool_masked_pos,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.layernorm(sequence_output)
prediction_scores = self.lm_head(sequence_output[:, 1:])
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return output
return FlaxMaskedLMOutput(
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
BEIT_START_DOCSTRING,
)
class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel):
module_class = FlaxBeitForMaskedImageModelingModule
FLAX_BEIT_MLM_DOCSTRING = """
bool_masked_pos (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Returns:
Examples::
>>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
>>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
>>> inputs = feature_extractor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING)
append_replace_return_docstrings(
FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig
)
class FlaxBeitForImageClassificationModule(nn.Module):
config: BeitConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True)
self.classifier = nn.Dense(
self.config.num_labels,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(
self,
pixel_values=None,
bool_masked_pos=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.beit(
pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
logits = self.classifier(pooled_output)
if not return_dict:
output = (logits,) + outputs[2:]
return output
return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
hidden states of the patch tokens) e.g. for ImageNet.
""",
BEIT_START_DOCSTRING,
)
class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
module_class = FlaxBeitForImageClassificationModule
FLAX_BEIT_CLASSIF_DOCSTRING = """
Returns:
Example::
>>> from transformers import BeitFeatureExtractor, FlaxBeitForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
>>> model = FlaxBeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
>>> inputs = feature_extractor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
"""
overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
)

View File

@ -321,6 +321,38 @@ class FlaxBartPreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxBeitForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBeitForMaskedImageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBeitModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBeitPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

View File

@ -0,0 +1,369 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
from transformers import BeitConfig
from transformers.file_utils import cached_property, is_flax_available, is_vision_available
from transformers.testing_utils import require_flax, require_vision, slow
from .test_configuration_common import ConfigTester
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor
if is_flax_available():
import jax
from transformers import FlaxBeitForImageClassification, FlaxBeitForMaskedImageModeling, FlaxBeitModel
if is_vision_available():
from PIL import Image
from transformers import BeitFeatureExtractor
class FlaxBeitModelTester(unittest.TestCase):
def __init__(
self,
parent,
vocab_size=100,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
):
self.parent = parent
self.vocab_size = vocab_size
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = BeitConfig(
vocab_size=self.vocab_size,
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
)
return config, pixel_values, labels
def create_and_check_model(self, config, pixel_values, labels):
model = FlaxBeitModel(config=config)
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_masked_lm(self, config, pixel_values, labels):
model = FlaxBeitForMaskedImageModeling(config=config)
result = model(pixel_values)
# expected sequence length = num_patches
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = FlaxBeitForImageClassification(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_flax
class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(FlaxBeitModel, FlaxBeitForImageClassification, FlaxBeitForMaskedImageModeling) if is_flax_available() else ()
)
def setUp(self) -> None:
self.model_tester = FlaxBeitModelTester(self)
self.config_tester = ConfigTester(self, config_class=BeitConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
# We need to override this test because in Beit, the seq_len equals the number of patches + 1
# we compute that here
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
# We neeed to override this test because Beit's forward signature is different than text models.
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
# We neeed to override this test because Beit expects pixel_values instead of input_ids
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(pixel_values, **kwargs):
return model(pixel_values=pixel_values, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# We need to override this test because in Beit, the seq_len equals the number of patches + 1
# we compute that here
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
num_patches = (config.image_size // config.patch_size) ** 2
seq_length = num_patches + 1 # we add 1 for the [CLS] token
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("microsoft/beit-base-patch16-224")
outputs = model(np.ones((1, 3, 224, 224)))
self.assertIsNotNone(outputs)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
class FlaxBeitModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
)
@slow
def test_inference_masked_image_modeling_head(self):
model = FlaxBeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
feature_extractor = self.default_feature_extractor
image = prepare_img()
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
# prepare bool_masked_pos
bool_masked_pos = np.ones((1, 196), dtype=np.bool)
# forward pass
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
logits = outputs.logits
# verify the logits
expected_shape = (1, 196, 8192)
self.assertEqual(logits.shape, expected_shape)
expected_slice = np.array(
[[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]]
)
self.assertTrue(np.allclose(logits[bool_masked_pos][:3, :3], expected_slice, atol=1e-2))
@slow
def test_inference_image_classification_head_imagenet_1k(self):
model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="np")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = (1, 1000)
self.assertEqual(logits.shape, expected_shape)
expected_slice = np.array([-1.2385, -1.0987, -1.0108])
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
expected_class_idx = 281
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
@slow
def test_inference_image_classification_head_imagenet_22k(self):
model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-large-patch16-224-pt22k-ft22k")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="np")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = (1, 21841)
self.assertEqual(logits.shape, expected_shape)
expected_slice = np.array([1.6881, -0.2787, 0.5901])
self.assertTrue(np.allclose(logits[0, :3], expected_slice, atol=1e-4))
expected_class_idx = 2396
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)

View File

@ -97,6 +97,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"FlaxBeitForMaskedImageModeling",
"BeitForMaskedImageModeling",
"CLIPTextModel",
"CLIPVisionModel",