mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add TF ViT MAE (#16255)
* ported TFViTMAEIntermediate and TFViTMAEOutput. * added TFViTMAEModel and TFViTMAEDecoder. * feat: added a noise argument in the implementation for reproducibility. * feat: vit mae models with an additional noise argument for reproducibility. Co-authored-by: ariG23498 <aritra.born2fly@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
7a9ef8181c
commit
5b40a37bc4
@ -260,7 +260,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||
| ViTMAE | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ViTMAE | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
|
||||
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
|
@ -41,13 +41,16 @@ fine-tuning, one can directly plug in the weights into a [`ViTForImageClassifica
|
||||
- Note that the encoder of MAE is only used to encode the visual patches. The encoded patches are then concatenated with mask tokens, which the decoder (which also
|
||||
consists of Transformer blocks) takes as input. Each mask token is a shared, learned vector that indicates the presence of a missing patch to be predicted. Fixed
|
||||
sin/cos position embeddings are added both to the input of the encoder and the decoder.
|
||||
- For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/11435359/146857310-f258c86c-fde6-48e8-9cee-badd2b21bd2c.png"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> MAE architecture. Taken from the <a href="https://arxiv.org/abs/2111.06377">original paper.</a> </small>
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/mae).
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). TensorFlow version of the model was contributed by [sayakpaul](https://github.com/sayakpaul) and
|
||||
[ariG23498](https://github.com/ariG23498) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/mae).
|
||||
|
||||
|
||||
## ViTMAEConfig
|
||||
|
||||
@ -64,3 +67,15 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
|
||||
|
||||
[[autodoc]] transformers.ViTMAEForPreTraining
|
||||
- forward
|
||||
|
||||
|
||||
## TFViTMAEModel
|
||||
|
||||
[[autodoc]] TFViTMAEModel
|
||||
- call
|
||||
|
||||
|
||||
## TFViTMAEForPreTraining
|
||||
|
||||
[[autodoc]] transformers.TFViTMAEForPreTraining
|
||||
- call
|
||||
|
@ -2135,6 +2135,13 @@ if is_tf_available():
|
||||
"TFViTPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vit_mae"].extend(
|
||||
[
|
||||
"TFViTMAEForPreTraining",
|
||||
"TFViTMAEModel",
|
||||
"TFViTMAEPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
[
|
||||
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -4170,6 +4177,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
|
||||
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
|
||||
from .models.vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
|
||||
from .models.wav2vec2 import (
|
||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFWav2Vec2ForCTC,
|
||||
|
@ -70,6 +70,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("blenderbot", "TFBlenderbotModel"),
|
||||
("blenderbot-small", "TFBlenderbotSmallModel"),
|
||||
("vit", "TFViTModel"),
|
||||
("vit_mae", "TFViTMAEModel"),
|
||||
("wav2vec2", "TFWav2Vec2Model"),
|
||||
("hubert", "TFHubertModel"),
|
||||
]
|
||||
@ -100,6 +101,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("tapas", "TFTapasForMaskedLM"),
|
||||
("funnel", "TFFunnelForPreTraining"),
|
||||
("mpnet", "TFMPNetForMaskedLM"),
|
||||
("vit_mae", "TFViTMAEForPreTraining"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -33,6 +33,12 @@ if is_torch_available():
|
||||
"ViTMAEPreTrainedModel",
|
||||
]
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_vit_mae"] = [
|
||||
"TFViTMAEForPreTraining",
|
||||
"TFViTMAEModel",
|
||||
"TFViTMAEPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
|
||||
@ -46,6 +52,9 @@ if TYPE_CHECKING:
|
||||
ViTMAEPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
1087
src/transformers/models/vit_mae/modeling_tf_vit_mae.py
Normal file
1087
src/transformers/models/vit_mae/modeling_tf_vit_mae.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -240,18 +240,21 @@ class ViTMAEEmbeddings(nn.Module):
|
||||
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
||||
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
|
||||
|
||||
def random_masking(self, sequence):
|
||||
def random_masking(self, sequence, noise=None):
|
||||
"""
|
||||
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
||||
noise.
|
||||
|
||||
Args:
|
||||
sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
|
||||
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
|
||||
mainly used for testing purposes to control randomness and maintain the reproducibility
|
||||
"""
|
||||
batch_size, seq_length, dim = sequence.shape
|
||||
len_keep = int(seq_length * (1 - self.config.mask_ratio))
|
||||
|
||||
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
|
||||
if noise is None:
|
||||
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
||||
@ -269,7 +272,7 @@ class ViTMAEEmbeddings(nn.Module):
|
||||
|
||||
return sequence_masked, mask, ids_restore
|
||||
|
||||
def forward(self, pixel_values):
|
||||
def forward(self, pixel_values, noise=None):
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
|
||||
@ -277,7 +280,7 @@ class ViTMAEEmbeddings(nn.Module):
|
||||
embeddings = embeddings + self.position_embeddings[:, 1:, :]
|
||||
|
||||
# masking: length -> length * config.mask_ratio
|
||||
embeddings, mask, ids_restore = self.random_masking(embeddings)
|
||||
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
||||
|
||||
# append cls token
|
||||
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
||||
@ -668,6 +671,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
noise=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
@ -709,7 +713,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, mask, ids_restore = self.embeddings(pixel_values)
|
||||
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -910,6 +914,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
noise=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
@ -941,6 +946,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
|
||||
outputs = self.vit(
|
||||
pixel_values,
|
||||
noise=noise,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -1987,6 +1987,27 @@ class TFViTPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFViTMAEForPreTraining(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFViTMAEModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFViTMAEPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
718
tests/vit_mae/test_modeling_tf_vit_mae.py
Normal file
718
tests/vit_mae/test_modeling_tf_vit_mae.py
Normal file
@ -0,0 +1,718 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow ViTMAE model. """
|
||||
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ViTMAEConfig
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow, torch_device
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFViTMAEForPreTraining, TFViTMAEModel
|
||||
from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ViTFeatureExtractor
|
||||
|
||||
|
||||
class TFViTMAEModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return ViTMAEConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = TFViTMAEModel(config=config)
|
||||
result = model(pixel_values, training=False)
|
||||
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
|
||||
# (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.image_size)
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
|
||||
|
||||
def create_and_check_for_pretraining(self, config, pixel_values, labels):
|
||||
model = TFViTMAEForPreTraining(config)
|
||||
result = model(pixel_values, training=False)
|
||||
# expected sequence length = num_patches
|
||||
image_size = to_2tuple(self.image_size)
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
expected_seq_len = num_patches
|
||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, pixel_values, labels) = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as ViTMAE does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (TFViTMAEModel, TFViTMAEForPreTraining) if is_tf_available() else ()
|
||||
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFViTMAEModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTMAEConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
# ViTMAE does not use inputs_embeds
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
def test_keyword_and_dict_args(self):
|
||||
# make the mask reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs_dict = model(inputs, noise=noise)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
outputs_keywords = model(**inputs_keywords, noise=noise)
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
output_keywords = outputs_keywords[0].numpy()
|
||||
|
||||
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
def test_numpy_arrays_inputs(self):
|
||||
# make the mask reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
|
||||
def prepare_numpy_arrays(inputs_dict):
|
||||
inputs_np_dict = {}
|
||||
for k, v in inputs_dict.items():
|
||||
if tf.is_tensor(v):
|
||||
inputs_np_dict[k] = v.numpy()
|
||||
else:
|
||||
inputs_np_dict[k] = np.array(k)
|
||||
|
||||
return inputs_np_dict
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_np = prepare_numpy_arrays(inputs)
|
||||
|
||||
output_for_dict_input = model(inputs_np, noise=noise)
|
||||
output_for_kw_input = model(**inputs_np, noise=noise)
|
||||
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# ViTMAE has a different seq_length
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
# make masks reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
pt_noise = torch.from_numpy(noise).to(device=torch_device)
|
||||
tf_noise = tf.constant(noise)
|
||||
|
||||
def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
|
||||
|
||||
pt_inputs_dict = {}
|
||||
for name, key in tf_inputs_dict.items():
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
|
||||
return pt_inputs_dict
|
||||
|
||||
def check_outputs(tf_outputs, pt_outputs, model_class, names):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
|
||||
debugging easier and faster.
|
||||
|
||||
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
|
||||
"""
|
||||
|
||||
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
|
||||
if type(tf_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(tf_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
|
||||
check_outputs(tf_output, pt_output, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
|
||||
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
elif isinstance(tf_outputs, tf.Tensor):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
|
||||
tf_outputs = tf_outputs.numpy()
|
||||
if isinstance(tf_outputs, np.float32):
|
||||
tf_outputs = np.array(tf_outputs, dtype=np.float32)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
tf_nans = np.isnan(tf_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[tf_nans] = 0
|
||||
tf_outputs[tf_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
def check_pt_tf_models(tf_model, pt_model):
|
||||
# we are not preparing a model with labels because of the formation
|
||||
# of the ViT MAE model
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
|
||||
pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs_dict = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||
}
|
||||
|
||||
# Original test: check without `labels`
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
|
||||
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
if self.has_attentions:
|
||||
config.output_attentions = True
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining outputs loss along with
|
||||
# logits and mask indices. loss and mask indicies are not suitable for integration
|
||||
# with other keras modules.
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# `pixel_values` implies that the input is an image
|
||||
inputs = tf.keras.Input(
|
||||
batch_shape=(
|
||||
3,
|
||||
self.model_tester.num_channels,
|
||||
self.model_tester.image_size,
|
||||
self.model_tester.image_size,
|
||||
),
|
||||
name="pixel_values",
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(inputs)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# `TFViTMAEForPreTraining` outputs are not recommended to be used for
|
||||
# downstream application. This is just to check if the outputs of
|
||||
# `TFViTMAEForPreTraining` can be integrated with other keras modules.
|
||||
if model_class.__name__ == "TFViTMAEForPreTraining":
|
||||
hidden_states = outputs_dict["logits"]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
def test_keras_save_load(self):
|
||||
# make mask reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
tf_main_layer_classes = set(
|
||||
module_member
|
||||
for model_class in self.all_model_classes
|
||||
for module in (import_module(model_class.__module__),)
|
||||
for module_member_name in dir(module)
|
||||
if module_member_name.endswith("MainLayer")
|
||||
# This condition is required, since `modeling_tf_clip.py` has 3 classes whose names end with `MainLayer`.
|
||||
and module_member_name[: -len("MainLayer")] == model_class.__name__[: -len("Model")]
|
||||
for module_member in (getattr(module, module_member_name),)
|
||||
if isinstance(module_member, type)
|
||||
and tf.keras.layers.Layer in module_member.__bases__
|
||||
and getattr(module_member, "_keras_serializable", False)
|
||||
)
|
||||
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
noise = tf.convert_to_tensor(noise)
|
||||
inputs_dict.update({"noise": noise})
|
||||
|
||||
for main_layer_class in tf_main_layer_classes:
|
||||
main_layer = main_layer_class(config)
|
||||
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
}
|
||||
|
||||
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
|
||||
outputs = model(inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
filepath = os.path.join(tmpdirname, "keras_model.h5")
|
||||
model.save(filepath)
|
||||
model = tf.keras.models.load_model(
|
||||
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
|
||||
)
|
||||
assert isinstance(model, tf.keras.Model)
|
||||
after_outputs = model(inputs_dict)
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
def test_save_load(self):
|
||||
# make mask reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model_input = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(model_input, noise=noise)
|
||||
|
||||
if model_class.__name__ == "TFViTMAEModel":
|
||||
out_2 = outputs.last_hidden_state.numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
else:
|
||||
out_2 = outputs.logits.numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
model = tf.keras.models.load_model(saved_model_dir)
|
||||
after_outputs = model(model_input, noise=noise)
|
||||
|
||||
if model_class.__name__ == "TFViTMAEModel":
|
||||
out_1 = after_outputs["last_hidden_state"].numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
else:
|
||||
out_1 = after_outputs["logits"].numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
def test_save_load_config(self):
|
||||
# make mask reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs = model(model_inputs, noise=noise)
|
||||
model_config = model.get_config()
|
||||
# make sure that returned config is jsonifiable, which is required by keras
|
||||
json.dumps(model_config)
|
||||
new_model = model_class.from_config(model.get_config())
|
||||
# make sure it also accepts a normal config
|
||||
_ = model_class.from_config(model.config)
|
||||
_ = new_model(model_inputs) # Build model
|
||||
new_model.set_weights(model.get_weights())
|
||||
after_outputs = new_model(model_inputs, noise=noise)
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
@unittest.skip(
|
||||
reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load
|
||||
to get deterministic results."""
|
||||
)
|
||||
def test_determinism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load""")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
model = TFViTMAEModel.from_pretrained("google/vit-base-patch16-224")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_vision
|
||||
class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
# make random mask reproducible across the PT and TF model
|
||||
np.random.seed(2)
|
||||
|
||||
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
|
||||
# prepare a noise vector that will be also used for testing the TF model
|
||||
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||
vit_mae_config = ViTMAEConfig()
|
||||
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(1, num_patches))
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs, noise=noise)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.convert_to_tensor([1, 196, 768])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
[[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
@ -17,13 +17,14 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import ViTMAEConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
@ -139,11 +140,7 @@ class ViTMAEModelTester:
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
labels,
|
||||
) = config_and_inputs
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -322,6 +319,153 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
|
||||
# to generate masks during test
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import transformers
|
||||
|
||||
# make masks reproducible
|
||||
np.random.seed(2)
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
num_patches = int((config.image_size // config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
|
||||
pt_noise = torch.from_numpy(noise).to(device=torch_device)
|
||||
tf_noise = tf.constant(noise)
|
||||
|
||||
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs_dict.items():
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
|
||||
return tf_inputs_dict
|
||||
|
||||
def check_outputs(tf_outputs, pt_outputs, model_class, names):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
|
||||
debugging easier and faster.
|
||||
|
||||
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
|
||||
"""
|
||||
|
||||
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
|
||||
if type(tf_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(tf_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
|
||||
check_outputs(tf_output, pt_output, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
|
||||
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
elif isinstance(tf_outputs, tf.Tensor):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
|
||||
tf_outputs = tf_outputs.numpy()
|
||||
if isinstance(tf_outputs, np.float32):
|
||||
tf_outputs = np.array(tf_outputs, dtype=np.float32)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
tf_nans = np.isnan(tf_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[tf_nans] = 0
|
||||
tf_outputs[tf_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
|
||||
# we are not preparing a model with labels because of the formation
|
||||
# of the ViT MAE model
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs_dict = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||
}
|
||||
|
||||
# Original test: check without `labels`
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
|
||||
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
tf_model = tf_model_class(config)
|
||||
pt_model = model_class(config)
|
||||
|
||||
# make sure only tf inputs are forward that actually exist in function args
|
||||
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
|
||||
|
||||
# remove all head masks
|
||||
tf_input_keys.discard("head_mask")
|
||||
tf_input_keys.discard("cross_attn_head_mask")
|
||||
tf_input_keys.discard("decoder_head_mask")
|
||||
|
||||
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
pt_model = pt_model.to(torch_device)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
|
||||
def test_save_load(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -400,11 +544,8 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
# make random mask reproducible
|
||||
# note that the same seed on CPU and on GPU doesn’t mean they spew the same random number sequences,
|
||||
# as they both have fairly different PRNGs (for efficiency reasons).
|
||||
# source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
|
||||
torch.manual_seed(2)
|
||||
# make random mask reproducible across the PT and TF model
|
||||
np.random.seed(2)
|
||||
|
||||
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
|
||||
|
||||
@ -412,22 +553,22 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# prepare a noise vector that will be also used for testing the TF model
|
||||
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||
vit_mae_config = ViTMAEConfig()
|
||||
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(1, num_patches))
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
outputs = model(**inputs, noise=torch.from_numpy(noise))
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 196, 768))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice_cpu = torch.tensor(
|
||||
[[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
|
||||
)
|
||||
expected_slice_gpu = torch.tensor(
|
||||
[[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
|
||||
)
|
||||
|
||||
# set expected slice depending on device
|
||||
expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user