Add TF DeiT implementation (#17806)

* Initial TF DeiT implementation

* Fix copies naming issues

* Fix up + docs

* Properly same main layer

* Name layers properly

* Initial TF DeiT implementation

* Fix copies naming issues

* Fix up + docs

* Properly same main layer

* Name layers properly

* Fixup

* Fix import

* Fix import

* Fix import

* Fix weight loading for tests whilst not on hub

* Add doc tests and remove to_2tuple

* Add back to_2tuple
Removing to_2tuple results in many downstream changes needed because of the copies checks

* Incorporate updates in Improve vision models #17731 PR

* Don't hard code num_channels

* Copy PyTorch DeiT embeddings and remove pytorch operations with mask

* Fix patch embeddings & tidy up

* Update PixelShuffle to move logic into class layer

* Update doc strings - remove PT references

* Use NHWC format in internal layers

* Fix up

* Use linear activation layer

* Remove unused import

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Move dataclass to top of file

* Remove from_pt now weights on hub

* Fixup

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
This commit is contained in:
amyeroberts 2022-07-13 18:04:08 +01:00 committed by GitHub
parent 7ea6ccc2b3
commit 8581a798c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1471 additions and 3 deletions

View File

@ -217,7 +217,7 @@ Flax), PyTorch, and/or TensorFlow.
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | ✅ | ✅ | ✅ | ❌ |
| Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| DeiT | ❌ | ❌ | ✅ | | ❌ |
| DeiT | ❌ | ❌ | ✅ | | ❌ |
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |

View File

@ -69,7 +69,7 @@ Tips:
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTFeatureExtractor`] in order to
prepare images for the model.
This model was contributed by [nielsr](https://huggingface.co/nielsr).
This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts).
## DeiTConfig
@ -100,3 +100,23 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr).
[[autodoc]] DeiTForImageClassificationWithTeacher
- forward
## TFDeiTModel
[[autodoc]] TFDeiTModel
- call
## TFDeiTForMaskedImageModeling
[[autodoc]] TFDeiTForMaskedImageModeling
- call
## TFDeiTForImageClassification
[[autodoc]] TFDeiTForImageClassification
- call
## TFDeiTForImageClassificationWithTeacher
[[autodoc]] TFDeiTForImageClassificationWithTeacher
- call

View File

@ -2167,6 +2167,16 @@ else:
"TFDebertaV2PreTrainedModel",
]
)
_import_structure["models.deit"].extend(
[
"TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDeiTForImageClassification",
"TFDeiTForImageClassificationWithTeacher",
"TFDeiTForMaskedImageModeling",
"TFDeiTModel",
"TFDeiTPreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
[
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4574,6 +4584,14 @@ if TYPE_CHECKING:
TFDebertaV2Model,
TFDebertaV2PreTrainedModel,
)
from .models.deit import (
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDeiTForImageClassification,
TFDeiTForImageClassificationWithTeacher,
TFDeiTForMaskedImageModeling,
TFDeiTModel,
TFDeiTPreTrainedModel,
)
from .models.distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,

View File

@ -685,6 +685,33 @@ class TFSemanticSegmenterOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFImageClassifierOutput(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
feature maps) of the model at the output of each stage.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFMultipleChoiceModelOutput(ModelOutput):
"""

View File

@ -42,6 +42,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("data2vec-vision", "TFData2VecVisionModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("deit", "TFDeiTModel"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("electra", "TFElectraModel"),
@ -166,6 +167,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("deit", "TFDeiTForMaskedImageModeling"),
("swin", "TFSwinForMaskedImageModeling"),
]
)
@ -175,6 +177,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("swin", "TFSwinForImageClassification"),

View File

@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"]}
@ -45,6 +51,21 @@ else:
"DeiTPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_deit"] = [
"TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFDeiTForImageClassification",
"TFDeiTForImageClassificationWithTeacher",
"TFDeiTForMaskedImageModeling",
"TFDeiTModel",
"TFDeiTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
@ -72,6 +93,21 @@ if TYPE_CHECKING:
DeiTPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_deit import (
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDeiTForImageClassification,
TFDeiTForImageClassificationWithTeacher,
TFDeiTForMaskedImageModeling,
TFDeiTModel,
TFDeiTPreTrainedModel,
)
else:
import sys

File diff suppressed because it is too large Load Diff

View File

@ -870,6 +870,44 @@ class TFDebertaV2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFDeiTForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFDeiTForImageClassificationWithTeacher(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFDeiTForMaskedImageModeling(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFDeiTModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFDeiTPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow DeiT model. """
import inspect
import unittest
import numpy as np
from transformers import DeiTConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import (
TFDeiTForImageClassification,
TFDeiTForImageClassificationWithTeacher,
TFDeiTForMaskedImageModeling,
TFDeiTModel,
)
from transformers.models.deit.modeling_tf_deit import TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import DeiTFeatureExtractor
class TFDeiTModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
encoder_stride=2,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 2
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return DeiTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFDeiTModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = TFDeiTForMaskedImageModeling(config=config)
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
config.num_channels = 1
model = TFDeiTForMaskedImageModeling(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFDeiTForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = TFDeiTForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFDeiTModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as DeiT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (
(
TFDeiTModel,
TFDeiTForImageClassification,
TFDeiTForImageClassificationWithTeacher,
TFDeiTForMaskedImageModeling,
)
if is_tf_available()
else ()
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFDeiTModelTester(self)
self.config_tester = ConfigTester(self, config_class=DeiTConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="DeiT does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
# special case for DeiTForImageClassificationWithTeacher model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
del inputs_dict["labels"]
return inputs_dict
@slow
def test_model_from_pretrained(self):
for model_name in TF_DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFDeiTModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class DeiTModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

View File

@ -26,6 +26,7 @@ src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py
src/transformers/models/deit/modeling_tf_deit.py
src/transformers/models/detr/modeling_detr.py
src/transformers/models/dpt/modeling_dpt.py
src/transformers/models/electra/modeling_electra.py