Add TFConvNextModel (#15750)
* feat: initial implementation of convnext in tensorflow.
* fix: sample code for the classification model.
* chore: added check for the classification model.
* chore: set bias initializer in the classification head.
* chore: updated license terms.
* chore: removed unused imports.
* feat: enabled argument during using drop_path.
* chore: replaced tf.identity with layers.Activation(linear).
* chore: edited default checkpoint.
* fix: minor bugs in the initializations.
* partial-fix: tf model errors for loading pretrained pt weights.
* partial-fix: call method updated.
* partial-fix: cross loading of weights (4x3 variables to be matched).
* chore: removed unneeded comment.
* removed playground.py
* rebasing
* rebasing and removing playground.py.
* fix: renaming TFConvNextStage conv and layer norm layers.
* chore: added initializers and other minor additions.
* chore: added initializers and other minor additions.
* add: tests for convnext.
* fix: integration tester class.
* fix: issues mentioned in pr feedback (round 1).
* fix: how output_hidden_states arg is propagated inside the network.
* feat: handling of arg for pure cnn models.
* chore: added a note on equal contribution in model docs.
* rebasing
* rebasing and removing playground.py.
* feat: encapsulation for the convnext trunk.
* Fix variable naming; test-related corrections; run make fixup.
* chore: added Joao as a contributor to convnext.
* rebasing
* rebasing and removing playground.py.
* rebasing
* rebasing and removing playground.py.
* chore: corrected copyright year and added comment on NHWC.
* chore: fixed the black version and ran formatting.
* chore: ran make style.
* chore: removed from_pt argument from test, ran make style.
* rebasing
* rebasing and removing playground.py.
* rebasing
* rebasing and removing playground.py.
* fix: tests in the convnext subclass, ran make style.
* rebasing
* rebasing and removing playground.py.
* rebasing
* rebasing and removing playground.py.
* chore: moved convnext test to the correct location.
* fix: locations for the test file of convnext.
* fix: convnext tests.
* chore: applied sgugger's suggestion for dealing w/ output_attentions.
* chore: added comments.
* chore: applied updated quality environment style.
* chore: applied formatting with quality environment.
* chore: reverted to the previous tests/test_modeling_common.py.
* chore: reverted to the original test_modeling_common.py.
* chore: reverted to previous states for test_modeling_tf_common.py and modeling_tf_utils.py.
* fix: tests for convnext.
* chore: removed output_attentions argument from convnext config.
* chore: reverted to the earlier tf utils.
* fix: output shapes of the hidden states.
* chore: removed unnecessary comment.
* chore: reverted to the right test_modeling_tf_common.py.
* Styling nits.

Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in: parent 0b5bf6abef, commit 84eaa6acf5
@@ -179,7 +179,7 @@ Flax), PyTorch, and/or TensorFlow.
 | Canine | ✅ | ❌ | ✅ | ❌ | ❌ |
 | CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
 | ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ConvNext | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ConvNext | ❌ | ❌ | ✅ | ✅ | ❌ |
 | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
 | DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
 | DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ |
@@ -37,7 +37,8 @@ alt="drawing" width="600"/>

 <small> ConvNeXT architecture. Taken from the <a href="https://arxiv.org/abs/2201.03545">original paper</a>.</small>

-This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).
+This model was contributed by [nielsr](https://huggingface.co/nielsr). TensorFlow version of the model was contributed by [ariG23498](https://github.com/ariG23498),
+[gante](https://github.com/gante), and [sayakpaul](https://github.com/sayakpaul) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).

 ## ConvNeXT specific outputs

@@ -63,4 +64,16 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 ## ConvNextForImageClassification

 [[autodoc]] ConvNextForImageClassification
     - forward
+
+
+## TFConvNextModel
+
+[[autodoc]] TFConvNextModel
+    - call
+
+
+## TFConvNextForImageClassification
+
+[[autodoc]] TFConvNextForImageClassification
+    - call
@@ -1743,6 +1743,13 @@ if is_tf_available():
             "TFConvBertPreTrainedModel",
         ]
     )
+    _import_structure["models.convnext"].extend(
+        [
+            "TFConvNextForImageClassification",
+            "TFConvNextModel",
+            "TFConvNextPreTrainedModel",
+        ]
+    )
     _import_structure["models.ctrl"].extend(
         [
             "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -3751,6 +3758,7 @@ if TYPE_CHECKING:
         TFConvBertModel,
         TFConvBertPreTrainedModel,
     )
+    from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
     from .models.ctrl import (
         TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFCTRLForSequenceClassification,
@@ -311,9 +311,10 @@ def booleans_processing(config, **kwargs):
     final_booleans = {}

     if tf.executing_eagerly():
-        final_booleans["output_attentions"] = (
-            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
-        )
+        # Pure conv models (such as ConvNext) do not have `output_attentions`
+        final_booleans["output_attentions"] = kwargs.get("output_attentions", None)
+        if final_booleans["output_attentions"] is None:
+            final_booleans["output_attentions"] = config.output_attentions
         final_booleans["output_hidden_states"] = (
             kwargs["output_hidden_states"]
             if kwargs["output_hidden_states"] is not None
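The key change here is swapping `kwargs["output_attentions"]` for `kwargs.get("output_attentions", None)`: pure conv models never put that key in `kwargs`, so direct indexing would raise `KeyError`. A minimal, stand-alone sketch of the difference (the `Config` class below is a hypothetical stand-in, not the real `PretrainedConfig`):

```python
# Minimal sketch of the behaviour change; `Config` is a hypothetical stand-in.
class Config:
    output_attentions = False

config = Config()

def old_style(kwargs):
    # Raises KeyError when the key is absent (e.g. for ConvNext)
    return kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions

def new_style(kwargs):
    # Falls back to the config default when the key is absent
    value = kwargs.get("output_attentions", None)
    return value if value is not None else config.output_attentions

print(new_style({}))  # False: missing key degrades to the config value
print(new_style({"output_attentions": True}))  # True
```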
@@ -36,6 +36,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("rembert", "TFRemBertModel"),
         ("roformer", "TFRoFormerModel"),
         ("convbert", "TFConvBertModel"),
+        ("convnext", "TFConvNextModel"),
         ("led", "TFLEDModel"),
         ("lxmert", "TFLxmertModel"),
         ("mt5", "TFMT5Model"),

@@ -155,6 +156,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Image-classification
         ("vit", "TFViTForImageClassification"),
+        ("convnext", "TFConvNextForImageClassification"),
     ]
 )
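With these two mapping entries in place, the TF auto classes can dispatch ConvNext checkpoints. A quick sanity-check sketch (needs TensorFlow and network access to the Hub):

```python
from transformers import TFAutoModelForImageClassification

# The new ("convnext", "TFConvNextForImageClassification") entry makes this resolve correctly
model = TFAutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224")
print(type(model).__name__)  # TFConvNextForImageClassification
```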
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING

 # rely on isort to merge the imports
-from ...file_utils import _LazyModule, is_torch_available, is_vision_available
+from ...file_utils import _LazyModule, is_tf_available, is_torch_available, is_vision_available


 _import_structure = {

@@ -36,6 +36,12 @@ if is_torch_available():
         "ConvNextPreTrainedModel",
     ]

+if is_tf_available():
+    _import_structure["modeling_tf_convnext"] = [
+        "TFConvNextForImageClassification",
+        "TFConvNextModel",
+        "TFConvNextPreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig

@@ -51,6 +57,9 @@ if TYPE_CHECKING:
             ConvNextPreTrainedModel,
         )

+    if is_tf_available():
+        from .modeling_tf_convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
+
 else:
     import sys
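Because both the `_import_structure` entry and the `TYPE_CHECKING` import are guarded by `is_tf_available()`, user code can probe for the backend the same way. A small sketch:

```python
from transformers.file_utils import is_tf_available

if is_tf_available():
    # Resolved lazily through the `_LazyModule` set up in this __init__.py
    from transformers.models.convnext import TFConvNextForImageClassification, TFConvNextModel
else:
    print("TensorFlow is not installed; only the PyTorch ConvNext classes are usable.")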
@@ -85,6 +85,7 @@ class ConvNextConfig(PretrainedConfig):
         is_encoder_decoder=False,
         layer_scale_init_value=1e-6,
         drop_path_rate=0.0,
+        image_size=224,
         **kwargs
     ):
         super().__init__(**kwargs)

@@ -99,3 +100,4 @@
         self.layer_norm_eps = layer_norm_eps
         self.layer_scale_init_value = layer_scale_init_value
         self.drop_path_rate = drop_path_rate
+        self.image_size = image_size
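The new `image_size` attribute is what `TFConvNextPreTrainedModel.dummy_inputs` (below) uses to build its `(3, num_channels, image_size, image_size)` probe tensor. A quick check:

```python
from transformers import ConvNextConfig

config = ConvNextConfig()  # image_size defaults to 224
print(config.image_size)   # 224
print(ConvNextConfig(image_size=384).image_size)  # 384
```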
618 src/transformers/models/convnext/modeling_tf_convnext.py Normal file
@@ -0,0 +1,618 @@
# coding=utf-8
# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 ConvNext model."""


from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    input_processing,
    keras_serializable,
)
from ...utils import logging
from .configuration_convnext import ConvNextConfig


logger = logging.get_logger(__name__)


_CONFIG_FOR_DOC = "ConvNextConfig"
_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
class TFConvNextDropPath(tf.keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_path
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
class TFConvNextEmbeddings(tf.keras.layers.Layer):
    """This class is comparable to (and inspired by) the SwinEmbeddings class
    found in src/transformers/models/swin/modeling_swin.py.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.patch_embeddings = tf.keras.layers.Conv2D(
            filters=config.hidden_sizes[0],
            kernel_size=config.patch_size,
            strides=config.patch_size,
            name="patch_embeddings",
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
        )
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")

    def call(self, pixel_values):
        if isinstance(pixel_values, dict):
            pixel_values = pixel_values["pixel_values"]

        # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        embeddings = self.patch_embeddings(pixel_values)
        embeddings = self.layernorm(embeddings)
        return embeddings
class TFConvNextLayer(tf.keras.layers.Layer):
    """This corresponds to the `Block` class in the original implementation.

    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in
    (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back

    The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
    NHWC ordering, we can just apply the operations straight-away without the permutation.

    Args:
        config ([`ConvNextConfig`]): Model configuration class.
        dim (`int`): Number of input channels.
        drop_path (`float`): Stochastic depth rate. Default: 0.0.
    """

    def __init__(self, config, dim, drop_path=0.0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.config = config
        self.dwconv = tf.keras.layers.Conv2D(
            filters=dim,
            kernel_size=7,
            padding="same",
            groups=dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="dwconv",
        )  # depthwise conv
        self.layernorm = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="layernorm",
        )
        self.pwconv1 = tf.keras.layers.Dense(
            units=4 * dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="pwconv1",
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = get_tf_activation(config.hidden_act)
        self.pwconv2 = tf.keras.layers.Dense(
            units=dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="pwconv2",
        )
        # Using `layers.Activation` instead of `tf.identity` to better control `training`
        # behaviour.
        self.drop_path = (
            TFConvNextDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else tf.keras.layers.Activation("linear", name="drop_path")
        )

    def build(self, input_shape: tf.TensorShape):
        # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
        self.layer_scale_parameter = (
            self.add_weight(
                shape=(self.dim,),
                initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_parameter",
            )
            if self.config.layer_scale_init_value > 0
            else None
        )
        super().build(input_shape)

    def call(self, hidden_states, training=False):
        input = hidden_states
        x = self.dwconv(hidden_states)
        x = self.layernorm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.layer_scale_parameter is not None:
            x = self.layer_scale_parameter * x

        x = input + self.drop_path(x, training=training)
        return x
class TFConvNextStage(tf.keras.layers.Layer):
    """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.

    Args:
        config ([`ConvNextConfig`]): Model configuration class.
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        depth (`int`): Number of residual blocks.
        drop_path_rates (`List[float]`): Stochastic depth rates for each layer.
    """

    def __init__(
        self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None, **kwargs
    ):
        super().__init__(**kwargs)
        if in_channels != out_channels or stride > 1:
            self.downsampling_layer = [
                tf.keras.layers.LayerNormalization(
                    epsilon=1e-6,
                    name="downsampling_layer.0",
                ),
                # Inputs to this layer will follow NHWC format since we
                # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
                # layer. All the outputs throughout the model will be in NHWC
                # from this point on until the output where we again change to
                # NCHW.
                tf.keras.layers.Conv2D(
                    filters=out_channels,
                    kernel_size=kernel_size,
                    strides=stride,
                    kernel_initializer=get_initializer(config.initializer_range),
                    bias_initializer="zeros",
                    name="downsampling_layer.1",
                ),
            ]
        else:
            self.downsampling_layer = [tf.identity]

        drop_path_rates = drop_path_rates or [0.0] * depth
        self.layers = [
            TFConvNextLayer(
                config,
                dim=out_channels,
                drop_path=drop_path_rates[j],
                name=f"layers.{j}",
            )
            for j in range(depth)
        ]

    def call(self, hidden_states):
        for layer in self.downsampling_layer:
            hidden_states = layer(hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
class TFConvNextEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.stages = []
        drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
        cur = 0
        prev_chs = config.hidden_sizes[0]
        for i in range(config.num_stages):
            out_chs = config.hidden_sizes[i]
            stage = TFConvNextStage(
                config,
                in_channels=prev_chs,
                out_channels=out_chs,
                stride=2 if i > 0 else 1,
                depth=config.depths[i],
                # each stage takes its own slice of the stochastic-depth schedule
                drop_path_rates=drop_path_rates[cur : cur + config.depths[i]],
                name=f"stages.{i}",
            )
            self.stages.append(stage)
            cur += config.depths[i]
            prev_chs = out_chs

    def call(self, hidden_states, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.stages):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states = layer_module(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
@keras_serializable
class TFConvNextMainLayer(tf.keras.layers.Layer):
    config_class = ConvNextConfig

    def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
        self.encoder = TFConvNextEncoder(config, name="encoder")
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        # We are setting the `data_format` like so because from here on we will revert to the
        # NCHW output format
        self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None

    def call(
        self,
        pixel_values: Optional[TFModelInputType] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
            kwargs_call=kwargs,
        )

        if "input_ids" in inputs:
            inputs["pixel_values"] = inputs.pop("input_ids")

        if inputs["pixel_values"] is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(inputs["pixel_values"], training=inputs["training"])

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=inputs["training"],
        )

        last_hidden_state = encoder_outputs[0]
        # Change to NCHW output format to have uniformity in the modules
        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
        pooled_output = self.layernorm(self.pooler(last_hidden_state))

        # Change the other hidden state outputs to NCHW as well
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
        )
class TFConvNextPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ConvNextConfig
    base_model_prefix = "convnext"
    main_input_name = "pixel_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
        VISION_DUMMY_INPUTS = tf.random.uniform(
            shape=(
                3,
                self.config.num_channels,
                self.config.image_size,
                self.config.image_size,
            ),
            dtype=tf.float32,
        )
        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}

    @tf.function(
        input_signature=[
            {
                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Method used for serving the model.

        Args:
            inputs (`Dict[str, tf.Tensor]`):
                The input of the saved model as a dictionary of tensors.
        """
        return self.call(inputs)
CONVNEXT_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TF 2.0 models accept two formats as inputs:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
    tensors in the first argument of the model call function: `model(inputs)`.

    </Tip>

    Parameters:
        config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

CONVNEXT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`ConvNextFeatureExtractor`]. See
            [`ConvNextFeatureExtractor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
            in eager mode; in graph mode the value will always be set to True.
"""
@add_start_docstrings(
    "The bare ConvNext model outputting raw features without any specific head on top.",
    CONVNEXT_START_DOCSTRING,
)
class TFConvNextModel(TFConvNextPreTrainedModel):
    def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: Optional[TFModelInputType] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import ConvNextFeatureExtractor, TFConvNextModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = feature_extractor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
            kwargs_call=kwargs,
        )

        if "input_ids" in inputs:
            inputs["pixel_values"] = inputs.pop("input_ids")

        if inputs["pixel_values"] is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.convnext(
            pixel_values=inputs["pixel_values"],
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=inputs["training"],
        )

        if not return_dict:
            return (outputs[0],) + outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
        )
@add_start_docstrings(
    """
    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    CONVNEXT_START_DOCSTRING,
)
class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.convnext = TFConvNextMainLayer(config, name="convnext")

        # Classifier head
        self.classifier = tf.keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="classifier",
        )

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: Optional[TFModelInputType] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import ConvNextFeatureExtractor, TFConvNextForImageClassification
        >>> import tensorflow as tf
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = feature_extractor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )

        if "input_ids" in inputs:
            inputs["pixel_values"] = inputs.pop("input_ids")

        if inputs["pixel_values"] is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.convnext(
            inputs["pixel_values"],
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=inputs["training"],
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)
        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)

        if not inputs["return_dict"]:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
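A quick way to see the stochastic-depth behaviour implemented by `TFConvNextDropPath` above, as a stand-alone sketch (assumes the class definition from this file is in scope):

```python
import tensorflow as tf

layer = TFConvNextDropPath(drop_path=0.5)
x = tf.ones((8, 4))

# At inference time the layer is the identity
print(tf.reduce_all(layer(x, training=False) == x).numpy())  # True

# In training, each sample survives with keep_prob = 0.5, and survivors are
# rescaled by 1 / keep_prob so the expected activation stays constant
out = layer(x, training=True)
print(sorted(set(out.numpy().flatten())))  # a subset of {0.0, 2.0}
```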
@@ -641,6 +641,27 @@ class TFConvBertPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])


+class TFConvNextForImageClassification(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFConvNextModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFConvNextPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
281 tests/convnext/test_modeling_tf_convnext.py Normal file
@@ -0,0 +1,281 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ConvNext model. """

import inspect
import unittest
from typing import List, Tuple

from transformers import ConvNextConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow

from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor


if is_tf_available():
    import tensorflow as tf

    from transformers import TFConvNextForImageClassification, TFConvNextModel


if is_vision_available():
    from PIL import Image

    from transformers import ConvNextFeatureExtractor
class TFConvNextModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=32,
        num_channels=3,
        num_stages=4,
        hidden_sizes=[10, 20, 30, 40],
        depths=[2, 2, 3, 2],
        is_training=True,
        use_labels=True,
        intermediate_size=37,
        hidden_act="gelu",
        type_sequence_label_size=10,
        initializer_range=0.02,
        num_labels=3,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_stages = num_stages
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.is_training = is_training
        self.use_labels = use_labels
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        return ConvNextConfig(
            num_channels=self.num_channels,
            hidden_sizes=self.hidden_sizes,
            depths=self.depths,
            num_stages=self.num_stages,
            hidden_act=self.hidden_act,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        model = TFConvNextModel(config=config)
        result = model(pixel_values, training=False)
        # expected last hidden states: B, C, H // 32, W // 32
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
        )

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        config.num_labels = self.type_sequence_label_size
        model = TFConvNextForImageClassification(config)
        result = model(pixel_values, labels=labels, training=False)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
@require_tf
class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (TFConvNextModel, TFConvNextForImageClassification) if is_tf_available() else ()

    test_pruning = False
    test_onnx = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = TFConvNextModelTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=ConvNextConfig,
            has_text_modality=False,
            hidden_size=37,
        )

    @unittest.skip(reason="ConvNext does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="ConvNext does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.call)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Model doesn't have attention layers")
    def test_attention_outputs(self):
        pass

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_stages = self.model_tester.num_stages
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

            # ConvNext's feature maps are of shape (batch_size, num_channels, height, width)
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [self.model_tester.image_size // 4, self.model_tester.image_size // 4],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    # Since ConvNext does not have any attention we need to rewrite this test.
    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assertTrue(
                        all(tf.equal(tuple_object, dict_object)),
                        msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
                    )

            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
        self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_tf
@require_vision
class TFConvNextModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
        return (
            ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") if is_vision_available() else None
        )

    @slow
    def test_inference_image_classification_head(self):
        model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(images=image, return_tensors="tf")

        # forward pass
        outputs = model(**inputs)

        # verify the logits
        expected_shape = tf.TensorShape((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = tf.constant([-0.0260, -0.4739, 0.1911])

        tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
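Usage note: the `@slow`-marked tests above (the pretrained download and the integration check) are skipped by default and only run when the `RUN_SLOW` environment variable is set, e.g. `RUN_SLOW=1 python -m pytest tests/convnext/test_modeling_tf_convnext.py`.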
@@ -474,8 +474,8 @@ class TFModelTesterMixin:
                 ),
                 "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
             }
-        # TODO: A better way to handle vision models
-        elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]:
+        # `pixel_values` implies that the input is an image
+        elif model_class.main_input_name == "pixel_values":
             inputs = tf.keras.Input(
                 batch_shape=(
                     3,