Resnet flax (#21472)
* [WIP] flax resnet
* added pretrained flax models, results reproducible
* Added pretrained flax models, results reproducible
* working on tests
* no real code change, just some comments
* [flax] adding support for batch norm layers
* fixing bugs related to pt+flax integration
* removing loss from modeling flax output class
* fixing classifier tests
* fixing comments, model output
* cleaning comments
* review changes
* review changes
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* renaming Flax to PyTorch

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Parent: 88dae78f4d
Commit: a0cbbba31f
@@ -285,7 +285,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ |
@@ -71,3 +71,13 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] TFResNetForImageClassification
    - call

## FlaxResNetModel

[[autodoc]] FlaxResNetModel
    - __call__

## FlaxResNetForImageClassification

[[autodoc]] FlaxResNetForImageClassification
    - __call__
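The section added above documents the new Flax classes alongside the existing PyTorch and TensorFlow ones. As a quick orientation, a minimal inference sketch with the new classifier (mirroring the usage examples added in the modeling file later in this diff, and assuming the `microsoft/resnet-50` checkpoint) might look like:

```python
# Minimal sketch: classify one image with the new Flax ResNet head.
import jax.numpy as jnp
import requests
from PIL import Image
from transformers import AutoImageProcessor, FlaxResNetForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

inputs = processor(images=image, return_tensors="np")  # NumPy arrays feed directly into the Flax model
logits = model(**inputs).logits
print(model.config.id2label[int(jnp.argmax(logits, axis=-1)[0])])
```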
@@ -237,7 +237,7 @@ Flax), PyTorch y/o TensorFlow.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -254,7 +254,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -339,7 +339,7 @@ specific language governing permissions and limitations under the License.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ |
@@ -308,7 +308,7 @@ specific language governing permissions and limitations under the License.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
@@ -252,7 +252,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -3637,6 +3637,9 @@ else:
                "FlaxPegasusPreTrainedModel",
            ]
        )
    _import_structure["models.resnet"].extend(
        ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"]
    )
    _import_structure["models.roberta"].extend(
        [
            "FlaxRobertaForCausalLM",
@@ -6692,6 +6695,7 @@ if TYPE_CHECKING:
        from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
        from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
        from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
        from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel
        from .models.roberta import (
            FlaxRobertaForCausalLM,
            FlaxRobertaForMaskedLM,
@@ -45,6 +45,64 @@ class FlaxBaseModelOutput(ModelOutput):
    attentions: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
            model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state after a pooling operation on the spatial dimensions.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
            model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: jnp.ndarray = None
    pooler_output: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
            for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None


@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
    """
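The three output classes added above are `flax.struct` dataclasses built on `ModelOutput`, so they can be unpacked positionally or read by attribute, and the feature maps follow the channels-first shapes documented in their docstrings. A small illustrative sketch (shapes chosen only as an example):

```python
# Sketch: constructing and reading one of the new outputs (values are dummies for illustration).
import jax.numpy as jnp
from transformers.modeling_flax_outputs import FlaxBaseModelOutputWithPoolingAndNoAttention

out = FlaxBaseModelOutputWithPoolingAndNoAttention(
    last_hidden_state=jnp.zeros((1, 2048, 7, 7)),  # (batch_size, num_channels, height, width)
    pooler_output=jnp.zeros((1, 2048)),            # (batch_size, hidden_size) per the docstring
)
assert out[0] is out.last_hidden_state  # ModelOutput also supports tuple-style indexing
assert out.hidden_states is None        # only filled when output_hidden_states=True is requested
```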
@@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
@@ -119,6 +120,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-classification
        ("beit", "FlaxBeitForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)
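With the two mapping entries added above, the Flax auto classes can now resolve ResNet checkpoints to the new model classes. A brief sketch of what that enables (assuming the `microsoft/resnet-50` checkpoint):

```python
# Sketch: the auto classes pick the Flax ResNet implementations based on the config's model_type.
from transformers import FlaxAutoModel, FlaxAutoModelForImageClassification

backbone = FlaxAutoModel.from_pretrained("microsoft/resnet-50")  # -> FlaxResNetModel
classifier = FlaxAutoModelForImageClassification.from_pretrained("microsoft/resnet-50")  # -> FlaxResNetForImageClassification
print(type(backbone).__name__, type(classifier).__name__)
```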
@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+)


_import_structure = {
@@ -47,6 +53,17 @@ else:
        "TFResNetPreTrainedModel",
    ]

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_resnet"] = [
        "FlaxResNetForImageClassification",
        "FlaxResNetModel",
        "FlaxResNetPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
@@ -78,6 +95,14 @@ if TYPE_CHECKING:
        TFResNetPreTrainedModel,
    )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel


else:
    import sys
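The `_import_structure` changes above keep the package import lazy: importing `transformers.models.resnet` does not pull in JAX/Flax, and `modeling_flax_resnet` is only loaded when one of its attributes is first accessed. A rough sketch of the observable behaviour (environment-dependent; assumes flax is installed):

```python
# Sketch: the Flax submodule is imported lazily, on first attribute access.
import sys

from transformers.models import resnet

print([m for m in sys.modules if "modeling_flax_resnet" in m])  # typically [] right after the package import
_ = resnet.FlaxResNetModel                                      # attribute access triggers the real import
print([m for m in sys.modules if "modeling_flax_resnet" in m])  # now lists the Flax ResNet module
```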
src/transformers/models/resnet/modeling_flax_resnet.py (new file, 701 lines)
@@ -0,0 +1,701 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import (
    FlaxBaseModelOutputWithNoAttention,
    FlaxBaseModelOutputWithPoolingAndNoAttention,
    FlaxImageClassifierOutputWithNoAttention,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_resnet import ResNetConfig


RESNET_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models)

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""


RESNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Identity(nn.Module):
    """Identity function."""

    @nn.compact
    def __call__(self, x):
        return x


class FlaxResNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
        self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


class FlaxResNetEmbeddings(nn.Module):
    """
    ResNet Embeddings (stem) composed of a single aggressive convolution.
    """

    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embedder = FlaxResNetConvLayer(
            self.config.embedding_size,
            kernel_size=7,
            stride=2,
            activation=self.config.hidden_act,
            dtype=self.dtype,
        )

        self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

    def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        num_channels = pixel_values.shape[-1]
        if num_channels != self.config.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embedding = self.embedder(pixel_values, deterministic=deterministic)
        embedding = self.max_pool(embedding)
        return embedding


class FlaxResNetShortCut(nn.Module):
    """
    ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """

    out_channels: int
    stride: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(1, 1),
            strides=self.stride,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        return hidden_state
class FlaxResNetBasicLayerCollection(nn.Module):
    out_channels: int
    stride: int = 1
    activation: Optional[str] = "relu"  # accepted (but unused here); the block's final activation lives in FlaxResNetBasicLayer
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer = [
            FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype),
            FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype),
        ]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        for layer in self.layer:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state


class FlaxResNetBasicLayer(nn.Module):
    """
    A classic ResNet's residual layer composed by two `3x3` convolutions.
    """

    in_channels: int
    out_channels: int
    stride: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype)
            if should_apply_shortcut
            else None
        )
        self.layer = FlaxResNetBasicLayerCollection(
            out_channels=self.out_channels,
            stride=self.stride,
            activation=self.activation,
            dtype=self.dtype,
        )
        self.activation_func = ACT2FN[self.activation]

    def __call__(self, hidden_state, deterministic: bool = True):
        residual = hidden_state
        hidden_state = self.layer(hidden_state, deterministic=deterministic)

        if self.shortcut is not None:
            residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state += residual

        hidden_state = self.activation_func(hidden_state)
        return hidden_state
class FlaxResNetBottleNeckLayerCollection(nn.Module):
    out_channels: int
    stride: int = 1
    activation: Optional[str] = "relu"
    reduction: int = 4
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        reduces_channels = self.out_channels // self.reduction

        self.layer = [
            FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"),
            FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"),
            FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"),
        ]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        for layer in self.layer:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state


class FlaxResNetBottleNeckLayer(nn.Module):
    """
    A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the
    input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution
    remaps the reduced features to `out_channels`.
    """

    in_channels: int
    out_channels: int
    stride: int = 1
    activation: Optional[str] = "relu"
    reduction: int = 4
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype)
            if should_apply_shortcut
            else None
        )

        self.layer = FlaxResNetBottleNeckLayerCollection(
            self.out_channels,
            stride=self.stride,
            activation=self.activation,
            reduction=self.reduction,
            dtype=self.dtype,
        )

        self.activation_func = ACT2FN[self.activation]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        residual = hidden_state

        if self.shortcut is not None:
            residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state = self.layer(hidden_state, deterministic)
        hidden_state += residual
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


class FlaxResNetStageLayersCollection(nn.Module):
    """
    A ResNet stage composed by stacked layers.
    """

    config: ResNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer

        layers = [
            # downsampling is done in the first layer with stride of 2
            layer(
                self.in_channels,
                self.out_channels,
                stride=self.stride,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
        ]

        for i in range(self.depth - 1):
            layers.append(
                layer(
                    self.out_channels,
                    self.out_channels,
                    activation=self.config.hidden_act,
                    dtype=self.dtype,
                    name=str(i + 1),
                )
            )

        self.layers = layers

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = x
        for layer in self.layers:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state
class FlaxResNetStage(nn.Module):
    """
    A ResNet stage composed by stacked layers.
    """

    config: ResNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = FlaxResNetStageLayersCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            depth=self.depth,
            dtype=self.dtype,
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        return self.layers(x, deterministic=deterministic)


class FlaxResNetStageCollection(nn.Module):
    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
        stages = [
            FlaxResNetStage(
                self.config,
                self.config.embedding_size,
                self.config.hidden_sizes[0],
                stride=2 if self.config.downsample_in_first_stage else 1,
                depth=self.config.depths[0],
                dtype=self.dtype,
                name="0",
            )
        ]

        for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
            stages.append(
                FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
            )

        self.stages = stages

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        hidden_states = () if output_hidden_states else None

        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

            hidden_state = stage_module(hidden_state, deterministic=deterministic)

        return hidden_state, hidden_states


class FlaxResNetEncoder(nn.Module):
    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        hidden_state, hidden_states = self.stages(
            hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
        )

        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)

        return FlaxBaseModelOutputWithNoAttention(
            last_hidden_state=hidden_state,
            hidden_states=hidden_states,
        )
class FlaxResNetPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ResNetConfig
    base_model_prefix = "resnet"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ResNetConfig,
        input_shape=(1, 224, 224, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        rngs = {"params": rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        train: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}

        return self.module.apply(
            {
                "params": params["params"] if params is not None else self.params["params"],
                "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"],
            },
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=["batch_stats"] if train else False,  # Returning tuple with batch_stats only when train is True
        )
class FlaxResNetModule(nn.Module):
    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype)

        # Adaptive average pooling used in resnet
        self.pooler = partial(
            nn.avg_pool,
            padding=((0, 0), (0, 0)),
        )

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embedding_output = self.embedder(pixel_values, deterministic=deterministic)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        last_hidden_state = encoder_outputs[0]

        pooled_output = self.pooler(
            last_hidden_state,
            window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
            strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
        ).transpose(0, 3, 1, 2)

        last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@add_start_docstrings(
    "The bare ResNet model outputting raw features without any specific head on top.",
    RESNET_START_DOCSTRING,
)
class FlaxResNetModel(FlaxResNetPreTrainedModel):
    module_class = FlaxResNetModule


FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxResNetModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
    FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig
)
class FlaxResNetClassifierCollection(nn.Module):
    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.classifier(x)


class FlaxResNetForImageClassificationModule(nn.Module):
    config: ResNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype)

        if self.config.num_labels > 0:
            self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype)
        else:
            self.classifier = Identity()

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.resnet(
            pixel_values,
            deterministic=deterministic,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output[:, :, 0, 0])

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """
    ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    RESNET_START_DOCSTRING,
)
class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel):
    module_class = FlaxResNetForImageClassificationModule


FLAX_VISION_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
"""

overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig
)
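Because `FlaxResNetPreTrainedModel.__call__` passes `mutable=["batch_stats"]` when `train=True`, a training call returns an `(outputs, mutated_variables)` pair, and the updated batch-norm statistics have to be carried forward next to the trainable parameters. A hypothetical training-step sketch (not part of this diff; `model`, `params`, `batch_stats`, `pixel_values` and `labels` are assumed to exist, and `optax` is only one possible choice of loss library):

```python
# Hypothetical sketch: consuming the mutable batch_stats returned by a train=True forward pass.
import jax
import optax  # assumed available

def train_step(params, batch_stats, pixel_values, labels):
    def loss_fn(trainable_params):
        outputs, mutated = model(  # `model` is a FlaxResNetForImageClassification instance
            pixel_values,
            params={"params": trainable_params, "batch_stats": batch_stats},
            train=True,  # BatchNorm uses batch statistics and returns the updated running stats
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(outputs.logits, labels).mean()
        return loss, mutated["batch_stats"]

    (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    # apply `grads` with the optimizer of your choice, then keep both the new params and new_batch_stats
    return loss, grads, new_batch_stats
```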
@@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["flax"])


class FlaxResNetForImageClassification(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxResNetModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxResNetPreTrainedModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRobertaForCausalLM(metaclass=DummyObject):
    _backends = ["flax"]
tests/models/resnet/test_modeling_flax_resnet.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

from transformers import ResNetConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers.utils import cached_property, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor


if is_flax_available():
    import jax
    import jax.numpy as jnp

    from transformers.models.resnet.modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel

if is_vision_available():
    from PIL import Image

    from transformers import AutoFeatureExtractor
class FlaxResNetModelTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=3,
        image_size=32,
        num_channels=3,
        embeddings_size=10,
        hidden_sizes=[10, 20, 30, 40],
        depths=[1, 1, 2, 1],
        is_training=True,
        use_labels=True,
        hidden_act="relu",
        num_labels=3,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.embeddings_size = embeddings_size
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_act = hidden_act
        self.num_labels = num_labels
        self.scope = scope
        self.num_stages = len(hidden_sizes)

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return ResNetConfig(
            num_channels=self.num_channels,
            embeddings_size=self.embeddings_size,
            hidden_sizes=self.hidden_sizes,
            depths=self.depths,
            hidden_act=self.hidden_act,
            num_labels=self.num_labels,
            image_size=self.image_size,
        )

    def create_and_check_model(self, config, pixel_values):
        model = FlaxResNetModel(config=config)
        result = model(pixel_values)

        # Output shape (b, c, h, w)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
        )

    def create_and_check_for_image_classification(self, config, pixel_values):
        config.num_labels = self.num_labels
        model = FlaxResNetForImageClassification(config=config)
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
@require_flax
class FlaxResNetModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxResNetModel, FlaxResNetForImageClassification) if is_flax_available() else ()

    is_encoder_decoder = False
    test_head_masking = False
    has_attentions = False

    def setUp(self) -> None:
        self.model_tester = FlaxResNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)

    def test_config(self):
        self.create_and_test_config_common_properties()
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def create_and_test_config_common_properties(self):
        return

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @unittest.skip(reason="ResNet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="ResNet does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_stages = self.model_tester.num_stages
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works when it is set on the config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip(reason="ResNet does not use feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(pixel_values, **kwargs):
|
||||
return model(pixel_values=pixel_values, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxResNetModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return AutoFeatureExtractor.from_pretrained("microsoft/resnet-50") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="np")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 1000)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = jnp.array([-11.1069, -9.7877, -8.3777])
|
||||
|
||||
self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
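The integration test above pins expected logits for `microsoft/resnet-50`; the commit message also mentions making the PyTorch and Flax results reproducible against each other. A hypothetical parity check in that spirit (not part of this diff) could compare the two backends directly:

```python
# Hypothetical sketch: compare PyTorch and Flax logits for the same image and checkpoint.
import numpy as np
import requests
from PIL import Image
from transformers import AutoImageProcessor, FlaxResNetForImageClassification, ResNetForImageClassification

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

pt_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
flax_model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

pt_logits = pt_model(**processor(images=image, return_tensors="pt")).logits.detach().numpy()
flax_logits = np.asarray(flax_model(**processor(images=image, return_tensors="np")).logits)

print(np.max(np.abs(pt_logits - flax_logits)))  # should be small (e.g. < 1e-3) if the port is faithful
```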