mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add TF DeiT implementation (#17806)
* Initial TF DeiT implementation * Fix copies naming issues * Fix up + docs * Properly same main layer * Name layers properly * Initial TF DeiT implementation * Fix copies naming issues * Fix up + docs * Properly same main layer * Name layers properly * Fixup * Fix import * Fix import * Fix import * Fix weight loading for tests whilst not on hub * Add doc tests and remove to_2tuple * Add back to_2tuple Removing to_2tuple results in many downstream changes needed because of the copies checks * Incorporate updates in Improve vision models #17731 PR * Don't hard code num_channels * Copy PyTorch DeiT embeddings and remove pytorch operations with mask * Fix patch embeddings & tidy up * Update PixelShuffle to move logic into class layer * Update doc strings - remove PT references * Use NHWC format in internal layers * Fix up * Use linear activation layer * Remove unused import * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Move dataclass to top of file * Remove from_pt now weights on hub * Fixup Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
7ea6ccc2b3
commit
8581a798c0
@ -217,7 +217,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| DeBERTa-v2 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DeiT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DeiT | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
@ -69,7 +69,7 @@ Tips:
|
||||
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTFeatureExtractor`] in order to
|
||||
prepare images for the model.
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||
|
||||
|
||||
## DeiTConfig
|
||||
@ -100,3 +100,23 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
|
||||
[[autodoc]] DeiTForImageClassificationWithTeacher
|
||||
- forward
|
||||
|
||||
## TFDeiTModel
|
||||
|
||||
[[autodoc]] TFDeiTModel
|
||||
- call
|
||||
|
||||
## TFDeiTForMaskedImageModeling
|
||||
|
||||
[[autodoc]] TFDeiTForMaskedImageModeling
|
||||
- call
|
||||
|
||||
## TFDeiTForImageClassification
|
||||
|
||||
[[autodoc]] TFDeiTForImageClassification
|
||||
- call
|
||||
|
||||
## TFDeiTForImageClassificationWithTeacher
|
||||
|
||||
[[autodoc]] TFDeiTForImageClassificationWithTeacher
|
||||
- call
|
||||
|
@ -2167,6 +2167,16 @@ else:
|
||||
"TFDebertaV2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.deit"].extend(
|
||||
[
|
||||
"TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDeiTForImageClassification",
|
||||
"TFDeiTForImageClassificationWithTeacher",
|
||||
"TFDeiTForMaskedImageModeling",
|
||||
"TFDeiTModel",
|
||||
"TFDeiTPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.distilbert"].extend(
|
||||
[
|
||||
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -4574,6 +4584,14 @@ if TYPE_CHECKING:
|
||||
TFDebertaV2Model,
|
||||
TFDebertaV2PreTrainedModel,
|
||||
)
|
||||
from .models.deit import (
|
||||
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDeiTForImageClassification,
|
||||
TFDeiTForImageClassificationWithTeacher,
|
||||
TFDeiTForMaskedImageModeling,
|
||||
TFDeiTModel,
|
||||
TFDeiTPreTrainedModel,
|
||||
)
|
||||
from .models.distilbert import (
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDistilBertForMaskedLM,
|
||||
|
@ -685,6 +685,33 @@ class TFSemanticSegmenterOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFImageClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of image classification models.
|
||||
|
||||
Args:
|
||||
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
|
||||
the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
|
||||
feature maps) of the model at the output of each stage.
|
||||
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
|
@ -42,6 +42,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-vision", "TFData2VecVisionModel"),
|
||||
("deberta", "TFDebertaModel"),
|
||||
("deberta-v2", "TFDebertaV2Model"),
|
||||
("deit", "TFDeiTModel"),
|
||||
("distilbert", "TFDistilBertModel"),
|
||||
("dpr", "TFDPRQuestionEncoder"),
|
||||
("electra", "TFElectraModel"),
|
||||
@ -166,6 +167,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("deit", "TFDeiTForMaskedImageModeling"),
|
||||
("swin", "TFSwinForMaskedImageModeling"),
|
||||
]
|
||||
)
|
||||
@ -175,6 +177,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
# Model for Image-classsification
|
||||
("convnext", "TFConvNextForImageClassification"),
|
||||
("data2vec-vision", "TFData2VecVisionForImageClassification"),
|
||||
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
|
||||
("regnet", "TFRegNetForImageClassification"),
|
||||
("resnet", "TFResNetForImageClassification"),
|
||||
("swin", "TFSwinForImageClassification"),
|
||||
|
@ -17,7 +17,13 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"]}
|
||||
@ -45,6 +51,21 @@ else:
|
||||
"DeiTPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_deit"] = [
|
||||
"TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDeiTForImageClassification",
|
||||
"TFDeiTForImageClassificationWithTeacher",
|
||||
"TFDeiTForMaskedImageModeling",
|
||||
"TFDeiTModel",
|
||||
"TFDeiTPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
|
||||
@ -72,6 +93,21 @@ if TYPE_CHECKING:
|
||||
DeiTPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_deit import (
|
||||
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDeiTForImageClassification,
|
||||
TFDeiTForImageClassificationWithTeacher,
|
||||
TFDeiTForMaskedImageModeling,
|
||||
TFDeiTModel,
|
||||
TFDeiTPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
1043
src/transformers/models/deit/modeling_tf_deit.py
Normal file
1043
src/transformers/models/deit/modeling_tf_deit.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -870,6 +870,44 @@ class TFDebertaV2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFDeiTForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDeiTForImageClassificationWithTeacher(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDeiTForMaskedImageModeling(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDeiTModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFDeiTPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
282
tests/models/deit/test_modeling_tf_deit.py
Normal file
282
tests/models/deit/test_modeling_tf_deit.py
Normal file
@ -0,0 +1,282 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow DeiT model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import DeiTConfig
|
||||
from transformers.testing_utils import require_tf, require_vision, slow
|
||||
from transformers.utils import cached_property, is_tf_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
TFDeiTForImageClassification,
|
||||
TFDeiTForImageClassificationWithTeacher,
|
||||
TFDeiTForMaskedImageModeling,
|
||||
TFDeiTModel,
|
||||
)
|
||||
from transformers.models.deit.modeling_tf_deit import TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import DeiTFeatureExtractor
|
||||
|
||||
|
||||
class TFDeiTModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
|
||||
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 2
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return DeiTConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = TFDeiTModel(config=config)
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
|
||||
model = TFDeiTForMaskedImageModeling(config=config)
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||
)
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = TFDeiTForMaskedImageModeling(config)
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = TFDeiTForImageClassification(config)
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = TFDeiTForImageClassification(config)
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFDeiTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_tf_common.py, as DeiT does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
TFDeiTModel,
|
||||
TFDeiTForImageClassification,
|
||||
TFDeiTForImageClassificationWithTeacher,
|
||||
TFDeiTForMaskedImageModeling,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDeiTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DeiTConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="DeiT does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_image_modeling(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
# special case for DeiTForImageClassificationWithTeacher model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
|
||||
del inputs_dict["labels"]
|
||||
|
||||
return inputs_dict
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFDeiTModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_vision
|
||||
class DeiTModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return (
|
||||
DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.TensorShape((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
|
||||
|
||||
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
@ -26,6 +26,7 @@ src/transformers/models/cvt/modeling_cvt.py
|
||||
src/transformers/models/data2vec/modeling_data2vec_audio.py
|
||||
src/transformers/models/data2vec/modeling_data2vec_vision.py
|
||||
src/transformers/models/deit/modeling_deit.py
|
||||
src/transformers/models/deit/modeling_tf_deit.py
|
||||
src/transformers/models/detr/modeling_detr.py
|
||||
src/transformers/models/dpt/modeling_dpt.py
|
||||
src/transformers/models/electra/modeling_electra.py
|
||||
|
Loading…
Reference in New Issue
Block a user