Port IDEFICS to tensorflow (#26870)

* Initial commit

* Just a copy of modeling_idefics.py that will be ported to TF

* - Prepend TF to the name of all classes
- Convert pytorch ops to TF (not all operations are converted yet)

* Add TF imports

* Add autotranslated files

* Add TF classes to model_tf_auto.py

* Add the TF classes in model_doc

* include auto-translated code

* Adopted from auto-translated version

* Add a forgotten super().build

* Add test code for TF version.

* Fix indentation and load pytorch weights for now

* Some fixes. Many tests are still failing but some are passing now.

- I have added TODOs for some of the hacks I made to unblock myself
  and I will address them soon
- I have hacked processing_idefics.py in my local checkout to support TF temporarily

* Add ALL_LAYERNORM_LAYERS to match pytorch

* Revert "Add ALL_LAYERNORM_LAYERS to match pytorch"

This reverts commit 7e0a35119b4d7a6284d04d8c543fba1b29e573c9 as it
is not needed in the TF implementation.

* Fix freeze_relevant_params()

* Some more fixes

* Fix test_attention_outputs

* Add tf stuff to processing_idefics.py

processing_idefics.py supports both PyTorch and TF now.

The PyTorch tests in test_processor_idefics.py are passing, so I didn't break anything,
but there are still some issues with TF. I also need to add TF tests to
test_processor_idefics.py.
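
For context, the behaviour this commit works towards, sketched with a dummy in-memory image (a hedged example, not part of the diff; the checkpoint is the tiny test model used later in this PR):

    import numpy as np
    from PIL import Image
    from transformers import IdeficsProcessor

    processor = IdeficsProcessor.from_pretrained("HuggingFaceM4/tiny-random-idefics")
    image = Image.fromarray(np.zeros((30, 30, 3), dtype=np.uint8))  # dummy image
    prompts = [["User:", image, "Describe this image.\nAssistant:"]]

    # The same call should work for both frameworks:
    pt_inputs = processor(prompts, return_tensors="pt")  # torch tensors
    tf_inputs = processor(prompts, return_tensors="tf")  # tf tensors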

* Pass return_tensors to image processing code and fix test

* Pass return_tensors to the image processor __init__

* Fix several test cases

- Make input to some of the forward pass of type `TFModelInputType`
- Decorate main layer forward pass with `@unpack_inputs`
- Decorate main layer with `@keras_serializable`
- Pass `inputs` to TFIdeficsModel
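
For context, the standard transformers TF pattern behind the two decorators above (a hedged sketch; the layer body is illustrative, only the decorators and the `config_class` requirement are the real API):

    import tensorflow as tf
    from transformers.modeling_tf_utils import keras_serializable, unpack_inputs
    from transformers.models.idefics.configuration_idefics import IdeficsConfig

    @keras_serializable  # lets Keras (de)serialize the layer along with its config
    class TFIdeficsMainLayerSketch(tf.keras.layers.Layer):
        config_class = IdeficsConfig  # required by @keras_serializable

        def __init__(self, config: IdeficsConfig, **kwargs):
            super().__init__(**kwargs)
            self.config = config

        @unpack_inputs  # normalizes positional/keyword/dict call styles before `call` runs
        def call(self, input_ids=None, attention_mask=None, training=False):
            return input_ids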

* Some more fixes forgotten in last commit

* Fix processing code and vision_tf.py

* Fix perceiver bug

* Import from

* Auto-add build() methods + style pass

* Fix build() errors due to `None` being passed as shape to some layers

* Change name in TFIdeficsForVisionText2Text to attribute in IdeficsForVisionText2Text

* Fix pytorch weights load for tf2

There were a lot of `name=` missing in weight initialization code.
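
Background: PT-to-TF cross-loading matches weights by their TF name scopes, so every sublayer needs an explicit `name=`. A minimal sketch of the pattern (illustrative names, not the actual model code):

    import tensorflow as tf

    class SubLayer(tf.keras.layers.Layer):
        def __init__(self, hidden_size, **kwargs):
            super().__init__(**kwargs)
            # Without name="dense", the variable gets an auto-generated scope
            # like "dense_3" and the PT->TF name matching cannot find it.
            self.dense = tf.keras.layers.Dense(hidden_size, name="dense")

        def call(self, x):
            return self.dense(x)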

* Attempt to fix CI

* Add back accidentally removed line

* Remove torch-specific stuff from the TF test file

* make fix-copies, make style, remove autotranslated files

* Fixes to imports/docstrings

* Let's try the from future import in desperation

* Fix the core random_attention_mask fn to match the torch/flax behaviour

* Clean random_attention_mask up correctly

* Remove torch-only test

* Fix loss shape, couple of nits

* make style

* Don't test for OOB embeddings because IDEFICS uses those deliberately

* Fix loss computation to handle masking
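
Presumably the standard masked language-modeling pattern (a hedged sketch, not the exact PR code): ignore label positions marked -100 and average the cross-entropy only over real tokens.

    import tensorflow as tf

    def masked_lm_loss(labels, logits, ignore_index=-100):
        labels = tf.reshape(labels, (-1,))
        logits = tf.reshape(logits, (-1, tf.shape(logits)[-1]))
        mask = tf.cast(labels != ignore_index, logits.dtype)
        # Replace ignored labels with 0 so the lookup inside the loss stays valid
        safe_labels = tf.where(labels == ignore_index, tf.zeros_like(labels), labels)
        per_token = tf.keras.losses.sparse_categorical_crossentropy(
            safe_labels, logits, from_logits=True
        )
        return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)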

* Fix test failures when flattening

* Fix some test failures

- Add the cross attention gate, which was missing and wasn't being passed around
- Fix overwriting of image_attention_mask due to a hack I had for dummy inputs
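
For reviewers: the gate here zeroes the cross-attention output for text tokens that have no image to attend to, so those tokens fall back to the pure language path. A minimal sketch of the masking step (hypothetical helper, simplified from the gated cross-attention layer):

    import tensorflow as tf

    def apply_cross_attention_gate(cross_attn_output, cross_attention_gate):
        # cross_attention_gate: [batch, seq], 0 where a token sees no image
        mask = tf.cast(cross_attention_gate == 0, cross_attn_output.dtype)
        return cross_attn_output * (1.0 - mask[:, :, None])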

* Add a proper stateless scaled_dot_product_attention

* make style

* Adding missing attribute from the PyTorch version

* Small cleanups to decoupledlinearlayer in case that helps

* Pass epsilon to LayerNormalization

* Attempt to fix PyTorch weight cross-loading for TFIdeficsEmbedding

* Fix a bug in TFIdeficsGatedCrossAttentionLayer

* Patching up build() methods

* Constant self.inv_freq

* Constant self.inv_freq

* First working version

The TF implementation works now; there was a bug in TFIdeficsDecoupledLinear
where the weights were mis-initialized as (in_features, out_features)
when they should be (out_features, in_features).

I have tested this so far with tiny-random and idefics-9b-instruct
and it gives correct output.

I also dumped the final outputs for both PyTorch and TF
and they are identical.
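
The convention behind the fix: torch.nn.Linear stores its kernel as (out_features, in_features), so a TF layer that cross-loads those weights must use the same ordering and transpose in the matmul. A minimal sketch (not the actual TFIdeficsDecoupledLinear):

    import tensorflow as tf

    class TorchStyleLinear(tf.keras.layers.Layer):
        def __init__(self, in_features, out_features, **kwargs):
            super().__init__(**kwargs)
            self.in_features = in_features
            self.out_features = out_features

        def build(self, input_shape):
            # (out_features, in_features) -- the ordering this commit fixes
            self.weight = self.add_weight(
                shape=(self.out_features, self.in_features), name="weight"
            )
            super().build(input_shape)

        def call(self, x):
            # y = x @ W^T, as in torch.nn.functional.linear
            return tf.linalg.matmul(x, self.weight, transpose_b=True)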

* Fix some test failures

* remove print statement

* Fix return_tensors

* Fix CI test failure check_code_quality

* Attempt to fix CI failures by running `make fixup`

The hardcoded IDs in test_modeling_tf_idefics.py are for the integration
test and make that file unreadable; they should probably be moved to a separate file.

* Attempt to fix tests_pr_documentation_tests

* Fix a test failure in test_image_processing_idefics.py

* Fix test test_pt_tf_model_equivalence

* Fix a few failures

* Tiny fix

* Some minor fixes

* Remove a duplicate test

* Override a few test failures for IDEFICS

- `test_keras_save_load` is passing now
- `test_compile_tf_model` is still failing

* Fix processing_idefics.py after rebase

* Guard import keras with is_tf_available

* fix check code quality

* fix check code quality

* Minor fixes

* Skip test_save_load temporarily

This test passed on my local box but fails on the CI; skipping it
for now to see if there are other remaining failures on the CI.

* Run `ruff format tests src utils`

* Fix last failing test, `test_compile_tf_model`

* Add fixes for vision_tf.py

I forgot to add this file in last commit.

* Minor fixes

* Replace ">>>" with ">>" for doc tests

IDEFICS-9B is too big for doctest runner, so don't run it there

* Make code more readable

* Fix bug after code review

I added a layer_norm_eps to IdeficsConfig but I don't even need it
since the vision config has a layer_norm_eps.

* Fix after code review

Use original code tokenizer.convert_tokens_to_ids

* Keep PyTorch as the default return_tensors

* Fixes to modeling_tf after code review

* Fixes from code review

- Remove all references of `TF_IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST`
- Pass 1e-5 to LayerNormalization in perceiver

* Run ruff

* Undo a change

* Refactor processing code after Matt's suggestion

* Remove TODO's that aren't needed anymore

* For pytorch, Use original pytorch processing code from main

Since this PR is a TF port it shouldn't make any modifications
to the PyTorch IDEFICS code. This change undoes the PyTorch processing
modifications I made and uses the original code from main.

* Update tests/models/idefics/test_modeling_idefics.py

* Update tests/models/idefics/test_modeling_tf_idefics.py

* Add missing imports for is_pt_tf_cross_test

* [DO NOT MERGE]: This is a commit for debugging and will be reverted

The cross test `test_pt_tf_model_equivalence` passes locally but
fails when running on the CI. This commit is to help debug that
and will be reverted.

* Revert "[DO NOT MERGE]: This is a commit for debugging and will be reverted"

This reverts commit 8f0d709ec5bd46685fb0b4259d914ffee794875b.

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 998cc38b8c3d313bf5e5eb55a7f5b7b881897b89.

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 1c695ac4219c4ae4d39b330b01744dc27deb7dd4.

* Don't skip test_save_load

IIRC test_save_load was also failing on the CI but not on my local
box; it might be easier to debug that on the CI first than the cross tests.

* Debugging commit, will be reverted

* Revert "Debugging commit, will be reverted"

This reverts commit 8eafc8e41e20c4e95a3a90834f06a6e9f445e2d5.

* Override `test_save_load` and push the model to the Hub

Maybe this will help me repro this weird bug

* pass my repo_id

* add endpoint

* Pass a temp (write) token just for this CI

* Undo last few commits, still pushing to hub for model debugging

The issue seems to be with save_pretrained(): when I looked at the model saved
from the CI test failure it is basically empty and has no weights.
`self.save_weights(...)` seems to be failing in save_pretrained but needs
more debugging.
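
A quick way to confirm the symptom (a hedged sketch; tf_model.h5 is the standard TF weights filename):

    import h5py

    # An "empty" save shows up as a file with (almost) no weight datasets.
    with h5py.File("tf_model.h5", "r") as f:
        f.visititems(
            lambda name, obj: print(name) if isinstance(obj, h5py.Dataset) else None
        )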

* Add logging to modeling tf utils, will be reverted just for debugging

* Debugging, will revert

* Revert "Debugging, will revert"

This reverts commit 9d0d3075fb7c82d8cde3a5c76bc8f3876c5c55d3.

* Revert "Add logging to modeling tf utils, will be reverted just for debugging"

This reverts commit 774b6b7b1c17b3ce5d7634ade768f2f686cee617.

* Remove `test_save_load`

The CI failures are gone after my latest rebase, no idea why,
but I was still saving the model to my hub on HF and the tf_model.h5
file now has everything.

* Run make fix-copies

* Run ruff format tests src utils

* Debugging commit, will be reverted

* Run ruff, also trigger CI run

* Run ruff again

* Undo debugging commit

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Alazar 2024-05-13 17:59:46 +03:00 committed by GitHub
parent de2f722172
commit 94306352f4
17 changed files with 3391 additions and 48 deletions

View File

@ -160,7 +160,7 @@ Flax), PyTorch, and/or TensorFlow.
| [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |

View File

@ -52,6 +52,16 @@ To train a new IDEFICS model from scratch use the m4 codebase (a link will be pr
[[autodoc]] IdeficsForVisionText2Text
- forward
## TFIdeficsModel
[[autodoc]] TFIdeficsModel
- call
## TFIdeficsForVisionText2Text
[[autodoc]] TFIdeficsForVisionText2Text
- call
## IdeficsImageProcessor
[[autodoc]] IdeficsImageProcessor

View File

@ -3862,6 +3862,15 @@ else:
"TFHubertPreTrainedModel",
]
)
_import_structure["models.idefics"].extend(
[
"TFIdeficsForVisionText2Text",
"TFIdeficsModel",
"TFIdeficsPreTrainedModel",
]
)
_import_structure["models.layoutlm"].extend(
[
"TFLayoutLMForMaskedLM",
@ -7905,6 +7914,11 @@ if TYPE_CHECKING:
TFHubertModel,
TFHubertPreTrainedModel,
)
from .models.idefics import (
TFIdeficsForVisionText2Text,
TFIdeficsModel,
TFIdeficsPreTrainedModel,
)
from .models.layoutlm import (
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,

View File

@ -58,6 +58,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("gptj", "TFGPTJModel"),
("groupvit", "TFGroupViTModel"),
("hubert", "TFHubertModel"),
("idefics", "TFIdeficsModel"),
("layoutlm", "TFLayoutLMModel"),
("layoutlmv3", "TFLayoutLMv3Model"),
("led", "TFLEDModel"),
@ -112,6 +113,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("funnel", "TFFunnelForPreTraining"),
("gpt-sw3", "TFGPT2LMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("idefics", "TFIdeficsForVisionText2Text"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("lxmert", "TFLxmertForPreTraining"),
("mobilebert", "TFMobileBertForPreTraining"),

View File

@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {"configuration_idefics": ["IdeficsConfig"]}
@ -39,6 +45,17 @@ else:
]
_import_structure["processing_idefics"] = ["IdeficsProcessor"]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_idefics"] = [
"TFIdeficsForVisionText2Text",
"TFIdeficsModel",
"TFIdeficsPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_idefics import IdeficsConfig
@ -64,6 +81,17 @@ if TYPE_CHECKING:
)
from .processing_idefics import IdeficsProcessor
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_idefics import (
TFIdeficsForVisionText2Text,
TFIdeficsModel,
TFIdeficsPreTrainedModel,
)
else:
import sys

View File

@ -92,8 +92,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
transform: Callable = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs,
) -> TensorType.PYTORCH:
) -> TensorType:
"""
Preprocess a batch of images.
@ -162,7 +163,6 @@ class IdeficsImageProcessor(BaseImageProcessor):
images = [self.rescale(image=image, scale=1 / 255) for image in images]
images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
# TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"]
return images

File diff suppressed because it is too large

View File

@ -0,0 +1,194 @@
# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
#
# MIT License
#
# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
References:
- DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
- Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
"""
from typing import Optional, Tuple
import tensorflow as tf
from ...modeling_tf_utils import shape_list
from .configuration_idefics import IdeficsConfig
class TFIdeficsPerceiverResampler(tf.keras.layers.Layer):
def __init__(
self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs
) -> None:
"""
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
returns a Tensor of shape [bsz, n_latents, embed_dim]. The embed_dim is the dimensionality of the embeddings
fed to the Perceiver Resampler (and also of the latent embeddings it *returns*); it could be e.g. the ViT
embed_dim, the ResNet pool dim, and so on.
Args:
config (`IdeficsConfig`): config object
embed_dim (`int`): The size of each embedding vector
depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
head_dim (`int`): Dimensionality of each head projection in the Transformer block.
n_latents (`int`):
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
"""
super().__init__(**kwargs)
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
self.intermediate_dim = (
self.embed_dim * 4
if not hasattr(config.vision_config, "embed_dim")
else config.vision_config.embed_dim * 4
)
# Create Transformer Blocks
self.blocks = []
for i in range(depth):
self.blocks.append(
[
TFIdeficsPerceiverAttention(
self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0"
),
TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"),
]
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
def build(self, input_shape):
# Create Latents for Perceiver
self.latents = self.add_weight(
shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents"
)
super().build(input_shape)
def call(self, context: tf.Tensor) -> tf.Tensor:
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
# Expand latents to the batch size (einops-style: repeat(self.latents, "seq embed -> bsz seq embed", bsz=bsz))
latents = tf.expand_dims(self.latents, axis=0)
latents = tf.tile(latents, [tf.shape(context)[0], 1, 1])
# Feed through Perceiver Attention blocks...
for attn, ff in self.blocks:
latents = attn(context, latents) + latents
latents = ff(latents) + latents
return self.layer_norm(latents)
class TFIdeficsPerceiverAttention(tf.keras.layers.Layer):
def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__(**kwargs)
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
self.qk_layer_norms = qk_layer_norms
# Normalization & Scaling
self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm")
self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm")
if self.qk_layer_norms:
self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm")
self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm")
self.qk_scale = self.head_dim**-0.5
# Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj")
self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj")
self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj")
self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj")
def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor:
"""
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
Args:
context (`tf.Tensor`):
Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
latents (`tf.Tensor`):
Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
Returns:
`tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
from context.
"""
context = self.context_layer_norm(context)
latents = self.latents_layer_norm(latents)
batch_size, seq_length, embed_dim = shape_list(context)
# Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
# Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
q = self.q_proj(latents)
k = self.k_proj(tf.concat([context, latents], axis=-2))
v = self.v_proj(tf.concat([context, latents], axis=-2))
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
q, k, v = [
tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3])
for x in (q, k, v)
]
if self.qk_layer_norms:
q = self.q_layer_norm(q)
k = self.k_layer_norm(k)
scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
attn = tf.nn.softmax(stabilized_scores, axis=-1)
# Attend & project back to output...
resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v)
return self.output_proj(
tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim))
)
class TFIdeficsMLP(tf.keras.layers.Layer):
def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs):
"""Simple MLP block with intermediate_size and embedding size"""
super().__init__(**kwargs)
self.embed_dim = config.vision_config.embed_dim
self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln")
self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc")
self.act = tf.keras.layers.ReLU(name="act")
self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj")
def call(self, hidden_states: Optional[Tuple[tf.Tensor]]) -> tf.Tensor:
hidden_states = self.ln(hidden_states)
hidden_states = self.fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states

View File

@ -22,34 +22,53 @@ from urllib.parse import urlparse
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...utils import TensorType, is_torch_available
from ...utils import is_tf_available, is_torch_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
IMAGE_TOKEN = "<image>"
# copied from m4.training.packing
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
# This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
# If any image index is >= num_classes, set it to -1.
# Tokens appearing after the maximum allowed number of images has been seen don't attend to anything
def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
# Set elements >= num_classes to -1
if num_classes != -1:
incremental_mask[incremental_mask >= num_classes] = -1
if return_tensors == "pt":
incremental_mask[incremental_mask >= num_classes] = -1
elif return_tensors == "tf":
incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask)
# Create mask for negative values
if return_tensors == "pt":
negatives = incremental_mask == -1
incremental_mask[negatives] = 0
attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
attn_mask[negatives, :] = 0
elif return_tensors == "tf":
negatives = tf.equal(incremental_mask, -1)
incremental_mask = tf.where(negatives, 0, incremental_mask)
attn_mask = tf.one_hot(incremental_mask, depth=num_classes)
# Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1]
negatives_expanded = tf.expand_dims(negatives, -1)
attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask)
negatives = incremental_mask == -1
incremental_mask[negatives] = 0
attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
attn_mask[negatives, :] = 0
return attn_mask
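# Illustrative example (not part of the diff): with num_classes=2, an
# incremental mask [[-1, 0, 1]] becomes [[[0, 0], [1, 0], [0, 1]]];
# rows that were -1 attend to no image, matching the comment above.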
# copied from m4.training.packing
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors):
if return_tensors == "pt":
return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer)
elif return_tensors == "tf":
return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer)
def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer):
image_attention_mask = torch.full_like(input_ids, fill_value=-1)
next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
@ -96,6 +115,39 @@ def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
return image_attention_mask, next_image_attention_mask
def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer):
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
eod_token_id = tokenizer.eos_token_id
batch_size = tf.shape(input_ids)[0]
image_attention_mask = tf.fill(tf.shape(input_ids), -1)
next_image_attention_mask = tf.fill(tf.shape(input_ids), -1)
for batch_idx in range(batch_size):
count = -1
seen_eod = False
seq_length = tf.shape(input_ids)[1]
for idx in range(seq_length - 1, -1, -1):
token_id = input_ids[batch_idx, idx].numpy()
if token_id == image_token_id:
count += 1
indices = [[batch_idx, idx]]
updates = [count]
image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates)
next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
elif token_id == eod_token_id and not seen_eod:
seen_eod = True
count = 0
indices = [[batch_idx, idx]]
updates = [count]
next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
if seen_eod and token_id != eod_token_id:
indices = [[batch_idx, idx]]
updates = [-1]
next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
return image_attention_mask, next_image_attention_mask
def is_url(string):
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
invalidated the url"""
@ -156,7 +208,7 @@ class IdeficsProcessor(ProcessorMixin):
add_eos_token=False,
add_end_of_utterance_token=None,
debug=False,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
return_tensors="pt",
) -> BatchEncoding:
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that
the model was trained on and prepares the image pixel values for the model to process.
@ -268,7 +320,6 @@ class IdeficsProcessor(ProcessorMixin):
# if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
if add_end_of_utterance_token is None:
add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
# turn non-batched prompts into batched
if not any(isinstance(i, list) for i in prompts):
prompts = [prompts]
@ -322,7 +373,7 @@ class IdeficsProcessor(ProcessorMixin):
if debug is True:
print(f"{full_text=}")
image_objects = self.image_processor(image_objects, transform=transform)
image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
all_prompts.append(full_text)
all_images.append(image_objects)
@ -345,39 +396,72 @@ class IdeficsProcessor(ProcessorMixin):
output_input_ids = []
output_images = []
output_attention_masks = []
for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
padded_input_ids = text
image_count = padded_input_ids.count(self.image_token_id)
local_max_num_images = min(image_count, max_num_images)
current_images = images[:local_max_num_images]
if len(current_images) > 0:
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
padded_image_tensor[: current_images.size(0)] = current_images
if return_tensors == "pt":
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
padded_image_tensor[: current_images.size(0)] = current_images
elif return_tensors == "tf":
# Assuming current_images is a TensorFlow tensor
# Get the shape of current_images, excluding the first dimension
image_shape = tf.shape(current_images)[1:]
# Create a shape for the padded_image_tensor
padded_shape = tf.concat([[max_num_images], image_shape], axis=0)
# Create the padded_image_tensor of zeros
padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype)
# Get the number of images (assuming current_images has shape [num_images, height, width, channels])
num_images = tf.shape(current_images)[0]
# Update the padded_image_tensor with the values from current_images
indices = tf.reshape(tf.range(num_images), (-1, 1))
updates = current_images
padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates)
else:
padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
if return_tensors == "pt":
padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
elif return_tensors == "tf":
padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims))
output_images.append(padded_image_tensor)
output_input_ids.append(torch.tensor(padded_input_ids))
output_attention_masks.append(torch.tensor(attention_mask))
if return_tensors == "pt":
output_input_ids.append(torch.tensor(padded_input_ids))
output_attention_masks.append(torch.tensor(attention_mask))
elif return_tensors == "tf":
output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32))
output_attention_masks.append(attention_mask)
output_input_ids = torch.stack(output_input_ids)
output_images = torch.stack(output_images)
output_attention_masks = torch.stack(output_attention_masks)
if return_tensors == "pt":
output_input_ids = torch.stack(output_input_ids)
output_images = torch.stack(output_images)
output_attention_masks = torch.stack(output_attention_masks)
elif return_tensors == "tf":
output_input_ids = tf.stack(output_input_ids)
output_images = tf.stack(output_images)
output_attention_masks = tf.stack(output_attention_masks)
if at_least_one_image:
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
output_input_ids, self.tokenizer, return_tensors
)
image_attention_mask = incremental_to_binary_attention_mask(
image_attention_mask, num_classes=max_num_images
image_attention_mask, return_tensors, num_classes=max_num_images
)
else:
# in full language mode we set the image mask to all-0s
image_attention_mask = torch.zeros(
output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
)
if return_tensors == "pt":
image_attention_mask = torch.zeros(
output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
)
elif return_tensors == "tf":
image_attention_mask = tf.zeros(
(output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool
)
return BatchFeature(
data={
"input_ids": output_input_ids,

View File

@ -0,0 +1,573 @@
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
from ...modeling_tf_utils import TFPreTrainedModel, shape_list
from ...tf_utils import flatten
from ...utils import ModelOutput, logging
from .configuration_idefics import IdeficsVisionConfig
logger = logging.get_logger(__name__)
@dataclass
class TFIdeficsVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[tf.Tensor] = None
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: IdeficsVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = tf.keras.layers.Conv2D(
filters=self.embed_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
use_bias=False,
padding="valid",
data_format="channels_last",
name="patch_embedding",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = tf.keras.layers.Embedding(
self.num_positions, self.embed_dim, name="position_embedding"
)
# self.position_ids = tf.range(self.num_positions)[tf.newaxis, :]
def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
num_patches = shape_list(embeddings)[1] - 1
pos_embed = self.position_embedding(self.position_ids)
num_positions = shape_list(pos_embed)[1] - 1
if num_patches == num_positions and height == width:
return pos_embed
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
embed_dim = shape_list(embeddings)[-1]
num_h_patches = height // self.config.patch_size
num_w_patches = width // self.config.patch_size
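# The + 0.1 below appears to mirror the PyTorch implementation: a small
# offset that guards against floating point rounding when the scaled grid
# size is computed for the interpolation.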
num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
sqrt_num_positions = math.sqrt(float(num_positions))
patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim))
scale_height = num_h_patches / sqrt_num_positions
scale_width = num_w_patches / sqrt_num_positions
original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32)
original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32)
# Apply scaling
new_height = tf.cast(original_height * scale_height, tf.int32)
new_width = tf.cast(original_width * scale_width, tf.int32)
patch_pos_embed = tf.image.resize(
patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC
)
if (
int(num_h_patches) != shape_list(patch_pos_embed)[-3]
or int(num_w_patches) != shape_list(patch_pos_embed)[-2]
):
raise ValueError(
f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
f"shape of position embedding ({shape_list(patch_pos_embed)[-2], shape_list(patch_pos_embed)[-1]})"
)
patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim))
return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1)
def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor:
# Input `pixel_values` is NCHW format which doesn't run on CPU so first thing we do is
# transpose it to change it to NHWC. We don't care to transpose it back because
# the Conv2D layer is only hit once for each query
if isinstance(pixel_values, dict):
pixel_values = pixel_values["pixel_values"]
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
batch_size, height, width, num_channels = shape_list(pixel_values)
if not interpolate_pos_encoding:
if height != self.image_size or width != self.image_size:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
)
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
# Change the 2D spatial dimensions to a single temporal dimension.
# shape = (batch_size, num_patches, out_channels=embed_dim)
patch_embeds = flatten(patch_embeds, 1, 2)
class_embeds = tf.broadcast_to(
self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim]
)
embeddings = tf.concat([class_embeds, patch_embeds], axis=1)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def build(self, input_shape=None):
if self.built:
return
self.built = True
self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :]
self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding")
if getattr(self, "patch_embedding", None) is not None:
with tf.name_scope(self.patch_embedding.name):
self.patch_embedding.build([None, None, None, self.config.num_channels])
if getattr(self, "position_embedding", None) is not None:
with tf.name_scope(self.position_embedding.name):
self.position_embedding.build(None)
class TFIdeficsVisionAttention(tf.keras.layers.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj")
self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj")
self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj")
self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj")
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
causal_attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = shape_list(hidden_states)
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
key_states = tf.reshape(key_states, proj_shape)
value_states = tf.reshape(value_states, proj_shape)
src_len = shape_list(key_states)[1]
attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True)
tf.debugging.assert_equal(
tf.shape(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}",
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]:
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {shape_list(causal_attention_mask)}"
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
if attention_mask is not None:
if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]:
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}"
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len))
else:
attn_weights_reshaped = None
attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
attn_output = tf.linalg.matmul(attn_probs, value_states)
tf.debugging.assert_equal(
tf.shape(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}",
)
attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "k_proj", None) is not None:
with tf.name_scope(self.k_proj.name):
self.k_proj.build((self.embed_dim, self.embed_dim))
if getattr(self, "v_proj", None) is not None:
with tf.name_scope(self.v_proj.name):
self.v_proj.build((self.embed_dim, self.embed_dim))
if getattr(self, "q_proj", None) is not None:
with tf.name_scope(self.q_proj.name):
self.q_proj.build((self.embed_dim, self.embed_dim))
if getattr(self, "out_proj", None) is not None:
with tf.name_scope(self.out_proj.name):
self.out_proj.build((self.embed_dim, self.embed_dim))
class TFIdeficsVisionMLP(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.activation_fn = get_tf_activation(config.hidden_act)
self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1")
self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2")
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build(self.config.hidden_size)
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build(self.config.intermediate_size)
class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: IdeficsVisionConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.hidden_size
self.self_attn = TFIdeficsVisionAttention(config, name="self_attn")
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
self.mlp = TFIdeficsVisionMLP(config, name="mlp")
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
causal_attention_mask: tf.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[tf.Tensor]:
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layer_norm1", None) is not None:
with tf.name_scope(self.layer_norm1.name):
self.layer_norm1.build([None, None, self.embed_dim])
if getattr(self, "layer_norm2", None) is not None:
with tf.name_scope(self.layer_norm2.name):
self.layer_norm2.build([None, None, self.embed_dim])
class TFIdeficsVisionEncoder(tf.keras.layers.Layer):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`TFIdeficsVisionEncoderLayer`].
Args:
config: IdeficsVisionConfig
"""
def __init__(self, config: IdeficsVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.layers = [
TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
]
self.gradient_checkpointing = False
def call(
self,
inputs_embeds,
attention_mask: Optional[tf.Tensor] = None,
causal_attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutput]:
r"""
Args:
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = tf.recompute_grad(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
causal_attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "layers", None) is not None:
for layer in self.layers:
with tf.name_scope(layer.name):
layer.build(None)
class TFIdeficsVisionTransformer(TFPreTrainedModel):
def __init__(self, config: IdeficsVisionConfig, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.embed_dim = config.hidden_size
self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings")
self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
self.encoder = TFIdeficsVisionEncoder(config, name="encoder")
self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[Tuple, TFBaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
if self.built:
return
self.built = True
if getattr(self, "embeddings", None) is not None:
with tf.name_scope(self.embeddings.name):
self.embeddings.build(None)
if getattr(self, "pre_layrnorm", None) is not None:
with tf.name_scope(self.pre_layrnorm.name):
self.pre_layrnorm.build([None, None, self.embed_dim])
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
if getattr(self, "post_layernorm", None) is not None:
with tf.name_scope(self.post_layernorm.name):
self.post_layernorm.build([None, self.embed_dim])

View File

@ -104,6 +104,33 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
return outputs
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: float = None
):
"""TF equivalent for torch's nn.functional.scaled_dot_product_attention"""
if dropout_p != 0.0:
raise ValueError(
"Dropout is not supported in this implementation - file an issue "
"with Transformers and ping @Rocketknight1 if you need it for a port!"
)
if is_causal and attn_mask is not None:
raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
if is_causal:
attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
# Convert boolean mask to a negative logit bias
attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
if scale is None:
scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
logits *= scale # scale by 1/sqrt(key_dim)
if attn_mask is not None:
logits += attn_mask
probs = tf.nn.softmax(logits)
return probs @ value
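# Illustrative usage (not part of the diff): with is_causal=True the function
# builds its own lower-triangular mask, so
#     scaled_dot_product_attention(q, k, v, is_causal=True)
# approximates torch.nn.functional.scaled_dot_product_attention(q, k, v,
# is_causal=True), up to the -1000.0 additive bias standing in for -inf.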
def flatten(input, start_dim=0, end_dim=-1):
# Replicates the behavior of torch.flatten in TF

View File

@ -1542,6 +1542,27 @@ class TFHubertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFIdeficsForVisionText2Text(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFIdeficsModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFIdeficsPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFLayoutLMForMaskedLM(metaclass=DummyObject):
_backends = ["tf"]

View File

@ -152,7 +152,7 @@ class IdeficsImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# they both do the same
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
image_processor = self.image_processing_class(**self.image_processor_dict)
image_processor = self.image_processing_class(**self.image_processor_dict, return_tensors="pt")
print(image_inputs)
@ -181,8 +181,8 @@ class IdeficsImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
]
)
pixel_values_transform_implied = image_processor(image_inputs, transform=None)
pixel_values_transform_supplied = image_processor(image_inputs, transform=transform)
pixel_values_transform_implied = image_processor(image_inputs, transform=None, return_tensors="pt")
pixel_values_transform_supplied = image_processor(image_inputs, transform=transform, return_tensors="pt")
torch.testing.assert_close(pixel_values_transform_implied, pixel_values_transform_supplied, rtol=0.0, atol=0.0)

View File

@ -21,6 +21,7 @@ from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
is_pt_tf_cross_test,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
@ -559,6 +560,11 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
check_hidden_states_output(inputs_dict, config, model_class)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
self.has_attentions = False
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
@slow
def test_model_from_pretrained(self):
model_name = "HuggingFaceM4/idefics-9b"

View File

@ -0,0 +1,565 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TF Idefics model. """
import os
import tempfile
import unittest
from importlib import import_module
from transformers import IdeficsConfig, is_tf_available, is_vision_available
from transformers.testing_utils import TestCasePlus, is_pt_tf_cross_test, require_tf, require_vision, slow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import IdeficsProcessor, TFIdeficsForVisionText2Text, TFIdeficsModel
from transformers.modeling_tf_utils import keras
from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
if is_vision_available():
from PIL import Image
IDEFICS_TINY_RANDOM_MODEL = "HuggingFaceM4/tiny-random-idefics"
class IdeficsModelTester:
def __init__(
self,
parent,
batch_size=1,
seq_length=7,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None,
modality_type_vocab_size=2,
vision_embed_dim=32,
vision_patch_size=2,
vision_image_size=30,
vision_num_attention_heads=4,
vision_num_hidden_layers=5,
vision_intermediate_size=37,
perceiver_qk_layer_norms_perceiver=False,
perceiver_resampler_depth=2,
perceiver_resampler_head_dim=8,
perceiver_resampler_n_heads=2,
perceiver_resampler_n_latents=16,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
self.modality_type_vocab_size = modality_type_vocab_size
self.vision_embed_dim = vision_embed_dim
self.vision_patch_size = vision_patch_size
self.vision_image_size = vision_image_size
self.vision_num_attention_heads = vision_num_attention_heads
self.vision_num_hidden_layers = vision_num_hidden_layers
self.vision_intermediate_size = vision_intermediate_size
self.vision_config = IdeficsVisionConfig(
embed_dim=self.vision_embed_dim,
patch_size=self.vision_patch_size,
image_size=self.vision_image_size,
num_attention_heads=self.vision_num_attention_heads,
num_hidden_layers=self.vision_num_hidden_layers,
intermediate_size=self.vision_intermediate_size,
)
self.perceiver_qk_layer_norms_perceiver = perceiver_qk_layer_norms_perceiver
self.perceiver_resampler_depth = perceiver_resampler_depth
self.perceiver_resampler_head_dim = perceiver_resampler_head_dim
self.perceiver_resampler_n_heads = perceiver_resampler_n_heads
self.perceiver_resampler_n_latents = perceiver_resampler_n_latents
self.perceiver_config = IdeficsPerceiverConfig(
qk_layer_norms_perceiver=self.perceiver_qk_layer_norms_perceiver,
resampler_depth=self.perceiver_resampler_depth,
resampler_head_dim=self.perceiver_resampler_head_dim,
resampler_n_heads=self.perceiver_resampler_n_heads,
resampler_n_latents=self.perceiver_resampler_n_latents,
)
# we set the expected sequence length (which is used in several tests)
# this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
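# e.g. with the defaults above: 7 + (30 // 2) ** 2 + 1 = 233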
def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
pixel_values = floats_tensor(
[
self.batch_size,
num_images,
self.num_channels,
self.image_size + image_expansion,
self.image_size + image_expansion,
]
)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
config = self.get_config()
return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding)
def get_config(self):
return IdeficsConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
num_labels=self.num_labels,
modality_type_vocab_size=self.modality_type_vocab_size,
vision_config=self.vision_config,
)
def create_and_check_model(
self,
config,
input_ids,
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
):
model = TFIdeficsModel(config=config)
result = model(
input_ids,
attention_mask=input_mask,
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
)
def create_and_check_model_gen(
self,
config,
input_ids,
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
):
model = TFIdeficsForVisionText2Text(config)
model.generate(
input_ids,
attention_mask=input_mask,
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
max_length=self.seq_length + 2,
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
pixel_values,
image_attention_mask,
interpolate_pos_encoding,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": input_mask,
"pixel_values": pixel_values,
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": interpolate_pos_encoding,
}
return config, inputs_dict
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_tf
class TFIdeficsModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFIdeficsModel, TFIdeficsForVisionText2Text) if is_tf_available() else ()
pipeline_model_mapping = {"feature-extraction": TFIdeficsModel} if is_tf_available() else {}
test_pruning = False
test_headmasking = False
test_onnx = False
test_resize_embeddings = False
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
# XXX: IdeficsForVisionText2TextTest has no MODEL_FOR group yet, but it should be the same
# as MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so for now manually changing to do the right thing
# as super won't do it
if return_labels:
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int64
)
return inputs_dict
def test_model_outputs_equivalence(self):
try:
orig = self.all_model_classes
# TFIdeficsModel.call doesn't have a labels input arg - only TFIdeficsForVisionText2Text does
self.all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else ()
super().test_model_outputs_equivalence()
finally:
self.all_model_classes = orig
def setUp(self):
self.model_tester = IdeficsModelTester(self)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=False, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=False, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_image_pos_embeddings_interpolation_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_with_image_pos_embeddings_interpolation_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=0
)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_generate_with_image_pos_embeddings_interpolation_single_image(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=1, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model_gen(*config_and_inputs)
def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(
num_images=2, interpolate_pos_encoding=True, image_expansion=2
)
self.model_tester.create_and_check_model_gen(*config_and_inputs)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="IDEFICS uses out-of-bounds embeddings deliberately.")
def test_embeddings_out_of_bounds_raise_exception(self):
pass
@unittest.skip(reason="IDEFICS attention weights are not extracted in scaled_dot_product_attention")
def test_prepare_serving_output(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
# IDEFICS does not support outputting attention scores because it uses SDPA under the hood
self.assertTrue(attentions[0] is None)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
# IDEFICS does not support outputting attention scores because it uses SDPA under the hood
self.assertTrue(self_attentions[0] is None)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
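# The TF port returns None attention weights (it uses SDPA under the hood), so skip comparing attentions here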
self.has_attentions = False
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
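# Gather every @keras_serializable *MainLayer class from the modules that define the model classes under test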
tf_main_layer_classes = {
module_member
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
for module_member_name in dir(module)
if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type)
and keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
}
for main_layer_class in tf_main_layer_classes:
main_layer = main_layer_class(config)
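# Only real tensors can become symbolic Keras inputs; non-tensor entries such as the interpolate_pos_encoding flag are filtered out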
symbolic_inputs = {
name: keras.Input(tensor.shape[1:], dtype=tensor.dtype, batch_size=2)
for name, tensor in inputs_dict.items()
if tf.is_tensor(tensor)
}
model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
outputs = model(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "keras_model.h5")
model.save(filepath)
model = keras.models.load_model(filepath, custom_objects={main_layer_class.__name__: main_layer_class})
assert isinstance(model, keras.Model)
after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)
@unittest.skip(reason="IDEFICS test_keras_fit testing done in TFIdeficsForVisionText2TextTest")
def test_keras_fit(self):
pass
@slow
def test_model_from_pretrained(self):
model = TFIdeficsModel.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True)
self.assertIsNotNone(model)
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
def test_saved_model_creation(self):
pass
@unittest.skip(reason="""IDEFICS loss computation not implemented yet""")
def test_loss_computation(self):
pass
@require_tf
class TFIdeficsForVisionText2TextTest(TFIdeficsModelTest, unittest.TestCase):
all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else ()
test_resize_embeddings = False
def setUp(self):
self.model_tester = IdeficsModelTester(
self,
modality_type_vocab_size=3,
)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
@unittest.skip("We only test the model that takes in multiple images")
def test_model(self):
pass
@unittest.skip("We only test the model that takes in multiple images")
def test_for_token_classification(self):
pass
@unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="""IDEFICS loss computation not implemented yet""")
def test_loss_computation(self):
pass
@slow
def test_keras_fit(self):
super().test_keras_fit()
# Below is the expected output for the integration test TFIdeficsModelIntegrationTest.
# Since we are using a tiny-random checkpoint to fit on the CI GPU, it is better to assert on the
# ids, because the generated text is gibberish.
# fmt: off
EXPECTED_GENERATED_IDS = [[0, 0, 1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 26361, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 5916, 14383, 1033, 12358, 10536, 21834, 10447, 21201, 18102, 16886, 8875, 25388, 25914, 28304, 8558, 31048, 1322, 25952, 189, 31600, 3600, 12824, 7045, 28090, 20228, 32001, 5385, 29186, 2165, 11822, 13825, 23077, 7883, 22504, 2078, 18893, 2179, 10556, 9515, 7672, 3491, 12403, 5398, 27299, 6463, 16349, 23037, 28956, 16960, 22664, 7724, 17587, 17424, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 29996, 313, 14502, 3241, 13618, 32001, 5385, 29186, 2165, 11822, 13825, 19934, 4875, 27142, 3230, 2709, 28054, 3270, 19148, 10917, 1060, 26443, 12259, 1347, 28482, 3830, 25519, 199, 12782, 9144, 12289, 1142, 18400, 21390, 19129, 7292, 28430, 24711, 5551, 30349, 30533, 13271, 17697, 4982, 8713, 5380, 17869, 12490, 5398, 27299, 11593, 19918, 15924, 29430, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 19234],
[1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 413, 986, 575, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 17554, 20500, 21714, 27834, 4798, 12195, 30379, 5427, 20228, 10473, 14351, 8049, 15605, 14491, 212, 2711, 32000, 21714, 31259, 24368, 19036, 22970, 26083, 19394, 20372, 7672, 9939, 25388, 30533, 8200, 30271, 2114, 24749, 13224, 10603, 21118, 2179, 3759, 16515, 6587, 1287, 23998, 17793, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 2943, 1221, 16043, 18244, 24965, 14383, 19840, 5980, 13488, 28531, 735, 26146, 22504, 2078, 18893, 20372, 7672, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 220, 10528, 28940, 4453, 28266, 15416, 18693, 8199, 1153, 27706, 29231, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 8231, 10739, 31992, 25906, 22254, 23127, 7689, 19614, 1149, 18844, 23037, 28956, 16960, 22664, 6975, 28938, 24002, 11026, 15020, 21964, 16307], ]
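# fmt: on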
@require_tf
@require_vision
class TFIdeficsModelIntegrationTest(TestCasePlus):
@cached_property
def default_processor(self):
return IdeficsProcessor.from_pretrained(IDEFICS_TINY_RANDOM_MODEL) if is_vision_available() else None
@slow
def test_inference_natural_language_visual_reasoning(self):
cat_image_path = self.tests_dir / "fixtures/tests_samples/COCO/000000039769.png"
cats_image_obj = Image.open(cat_image_path) # 2 cats
dogs_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg"
prompts = [
[
"User:",
dogs_image_url,
"Describe this image.\nAssistant: An image of two dogs.\n",
"User:",
cats_image_obj,
"Describe this image.\nAssistant:",
],
[
"User:",
cats_image_obj,
"Describe this image.\nAssistant: An image of two kittens.\n",
"User:",
dogs_image_url,
"Describe this image.\nAssistant:",
],
]
model = TFIdeficsForVisionText2Text.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True)
processor = self.default_processor
inputs = processor(prompts, return_tensors="tf")
generated_ids = model.generate(**inputs, max_length=100)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
# keep for debugging
for i, t in enumerate(generated_text):
t = bytes(t, "utf-8").decode("unicode_escape")
print(f"{i}:\n{t}\n")
self.assertListEqual(EXPECTED_GENERATED_IDS[0], generated_ids[0].numpy().tolist())
self.assertListEqual(EXPECTED_GENERATED_IDS[1], generated_ids[1].numpy().tolist())

View File

@@ -41,7 +41,7 @@ class IdeficsProcessorTest(TestCasePlus):
self.checkpoint_path = self.get_auto_remove_tmp_dir()
-image_processor = IdeficsImageProcessor()
+image_processor = IdeficsImageProcessor(return_tensors="pt")
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics")
processor = IdeficsProcessor(image_processor, tokenizer)
@@ -132,7 +132,7 @@ class IdeficsProcessorTest(TestCasePlus):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
-processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
@@ -145,7 +145,7 @@ class IdeficsProcessorTest(TestCasePlus):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer(padding_side="right")
-processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
predicted_tokens = [
"<s> Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk>",
@@ -156,8 +156,9 @@ class IdeficsProcessorTest(TestCasePlus):
([1] * 10) + ([0] * 10),
]
prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
-max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
-longest = processor(prompts, padding="longest", truncation=True, max_length=30)
+max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
+longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -203,7 +204,7 @@ class IdeficsProcessorTest(TestCasePlus):
processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
prompts = self.prepare_prompts()
-inputs = processor(prompts, padding="longest")
+inputs = processor(prompts, padding="longest", return_tensors="pt")
# For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
self.assertSetEqual(set(inputs.keys()), set(self.input_keys))

View File

@@ -380,7 +380,9 @@ class TFModelTesterMixin:
main_layer = main_layer_class(config)
symbolic_inputs = {
-name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
+name: keras.Input(tensor.shape[1:], dtype=tensor.dtype)
+for name, tensor in inputs_dict.items()
+if tf.is_tensor(tensor)
}
model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
@@ -1689,7 +1691,11 @@ class TFModelTesterMixin:
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
if "labels" not in tf_inputs_dict:
return # This model isn't giving us labels after all, don't try training with it
-tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
+tf_inputs_dict = {
+key: val
+for key, val in tf_inputs_dict.items()
+if "head_mask" not in key and isinstance(val, tf.Tensor)
+}
tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
input_dataset = Dataset.from_dict(tf_inputs_dict)
tf_dataset = model.prepare_tf_dataset(
@@ -1853,8 +1859,8 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
def random_attention_mask(shape, rng=None, name=None, dtype=None):
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng, name=name, dtype=dtype)
-# make sure that at least one token is attended to for each batch
-attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
+# Mark the first token as 1 (matches behaviour of PyTorch/Flax function)
+attn_mask = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1)
return attn_mask
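As a quick sanity check of the new behaviour (a sketch, assuming eager execution with the helpers above in scope):

mask = random_attention_mask([2, 5])
# the first position is now always attended to, matching the PyTorch/Flax helpers
assert bool(tf.reduce_all(mask[:, 0] == 1))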