* Add Aria
---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Aymeric Roucher 2024-12-06 12:17:34 +01:00 committed by GitHub
parent 15ab310c3a
commit 9ad4c93536
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 6244 additions and 7 deletions

View File

@ -810,6 +810,8 @@
title: ALIGN
- local: model_doc/altclip
title: AltCLIP
- local: model_doc/aria
title: Aria
- local: model_doc/blip
title: BLIP
- local: model_doc/blip-2

View File

@ -62,6 +62,8 @@ Flax), PyTorch, and/or TensorFlow.
| [ALBERT](model_doc/albert) | ✅ | ✅ | ✅ |
| [ALIGN](model_doc/align) | ✅ | ❌ | ❌ |
| [AltCLIP](model_doc/altclip) | ✅ | ❌ | ❌ |
| [Aria](model_doc/aria) | ✅ | ❌ | ❌ |
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
@ -172,6 +174,7 @@ Flax), PyTorch, and/or TensorFlow.
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |
| [Idefics3VisionTransformer](model_doc/idefics3_vision) | ❌ | ❌ | ❌ |
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |

View File

@ -0,0 +1,106 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Aria
## Overview
The Aria model was proposed in [Aria: An Open Multimodal Native Mixture-of-Experts Model](https://huggingface.co/papers/2410.05993) by Li et al. from the Rhymes.AI team.
Aria is an open multimodal-native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. It has a Mixture-of-Experts architecture, with respectively 3.9B and 3.5B activated parameters per visual token and text token.
The abstract from the paper is the following:
*Information comes in diverse modalities. Multimodal native AI models are essential to integrate real-world information and deliver comprehensive understanding. While proprietary multimodal native models exist, their lack of openness imposes obstacles for adoptions, let alone adaptations. To fill this gap, we introduce Aria, an open multimodal native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. Aria is a mixture-of-expert model with 3.9B and 3.5B activated parameters per visual token and text token, respectively. It outperforms Pixtral-12B and Llama3.2-11B, and is competitive against the best proprietary models on various multimodal tasks. We pre-train Aria from scratch following a 4-stage pipeline, which progressively equips the model with strong capabilities in language understanding, multimodal understanding, long context window, and instruction following. We open-source the model weights along with a codebase that facilitates easy adoptions and adaptations of Aria in real-world applications.*
This model was contributed by [m-ric](https://huggingface.co/m-ric).
The original code can be found [here](https://github.com/rhymes-ai/Aria).
## Usage tips
Here's how to use the model for vision tasks:
```python
import requests
import torch
from PIL import Image
from transformers import AriaProcessor, AriaForConditionalGeneration
model_id_or_path = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(
model_id_or_path, device_map="auto"
)
processor = AriaProcessor.from_pretrained(model_id_or_path)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"text": "what is the image?", "type": "text"},
],
}
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs.to(model.device)
output = model.generate(
**inputs,
max_new_tokens=15,
stop_strings=["<|im_end|>"],
tokenizer=processor.tokenizer,
do_sample=True,
temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
response = processor.decode(output_ids, skip_special_tokens=True)
```
## AriaImageProcessor
[[autodoc]] AriaImageProcessor
## AriaProcessor
[[autodoc]] AriaProcessor
## AriaTextConfig
[[autodoc]] AriaTextConfig
## AriaConfig
[[autodoc]] AriaConfig
## AriaTextModel
[[autodoc]] AriaTextModel
## AriaTextForCausalLM
[[autodoc]] AriaTextForCausalLM
## AriaForConditionalGeneration
[[autodoc]] AriaForConditionalGeneration
- forward

View File

@ -51,6 +51,13 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
[[autodoc]] Idefics3Config
## Idefics3VisionConfig
[[autodoc]] Idefics3VisionConfig
## Idefics3VisionTransformer
[[autodoc]] Idefics3VisionTransformer
## Idefics3Model

View File

@ -37,6 +37,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them
FlashAttention-2 is currently supported for the following architectures:
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
@ -216,6 +217,7 @@ PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.o
For now, Transformers supports SDPA inference and training for the following architectures:
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)

View File

@ -170,6 +170,11 @@ _import_structure = {
"AltCLIPTextConfig",
"AltCLIPVisionConfig",
],
"models.aria": [
"AriaConfig",
"AriaProcessor",
"AriaTextConfig",
],
"models.audio_spectrogram_transformer": [
"ASTConfig",
"ASTFeatureExtractor",
@ -1176,6 +1181,7 @@ else:
_import_structure["image_processing_base"] = ["ImageProcessingMixin"]
_import_structure["image_processing_utils"] = ["BaseImageProcessor"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.aria"].extend(["AriaImageProcessor"])
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
_import_structure["models.blip"].extend(["BlipImageProcessor"])
@ -1406,6 +1412,15 @@ else:
"AltCLIPVisionModel",
]
)
_import_structure["models.aria"].extend(
[
"AriaForConditionalGeneration",
"AriaPreTrainedModel",
"AriaTextForCausalLM",
"AriaTextModel",
"AriaTextPreTrainedModel",
]
)
_import_structure["models.audio_spectrogram_transformer"].extend(
[
"ASTForAudioClassification",
@ -2461,6 +2476,8 @@ else:
"Idefics3Model",
"Idefics3PreTrainedModel",
"Idefics3Processor",
"Idefics3VisionConfig",
"Idefics3VisionTransformer",
]
)
_import_structure["models.ijepa"].extend(
@ -5033,6 +5050,11 @@ if TYPE_CHECKING:
AltCLIPTextConfig,
AltCLIPVisionConfig,
)
from .models.aria import (
AriaConfig,
AriaProcessor,
AriaTextConfig,
)
from .models.audio_spectrogram_transformer import (
ASTConfig,
ASTFeatureExtractor,
@ -6096,6 +6118,7 @@ if TYPE_CHECKING:
from .image_processing_base import ImageProcessingMixin
from .image_processing_utils import BaseImageProcessor
from .image_utils import ImageFeatureExtractionMixin
from .models.aria import AriaImageProcessor
from .models.beit import BeitFeatureExtractor, BeitImageProcessor
from .models.bit import BitImageProcessor
from .models.blip import BlipImageProcessor
@ -6325,6 +6348,13 @@ if TYPE_CHECKING:
AltCLIPTextModel,
AltCLIPVisionModel,
)
from .models.aria import (
AriaForConditionalGeneration,
AriaPreTrainedModel,
AriaTextForCausalLM,
AriaTextModel,
AriaTextPreTrainedModel,
)
from .models.audio_spectrogram_transformer import (
ASTForAudioClassification,
ASTModel,
@ -7189,6 +7219,8 @@ if TYPE_CHECKING:
Idefics3Model,
Idefics3PreTrainedModel,
Idefics3Processor,
Idefics3VisionConfig,
Idefics3VisionTransformer,
)
from .models.ijepa import (
IJepaForImageClassification,

View File

@ -1465,6 +1465,7 @@ class GenerationMixin:
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and input_ids_length != 0
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]

View File

@ -16,6 +16,7 @@ from . import (
albert,
align,
altclip,
aria,
audio_spectrogram_transformer,
auto,
autoformer,

View File

@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_aria import *
from .image_processing_aria import *
from .modeling_aria import *
from .processing_aria import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,299 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ..auto import CONFIG_MAPPING, AutoConfig
class AriaTextConfig(PretrainedConfig):
r"""
This class handles the configuration for the text component of the Aria model.
Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
[rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 4096):
The size of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 2):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_heads
moe_num_experts (`int`, *optional*, defaults to 8):
The number of experts in the MoE layer.
moe_topk (`int`, *optional*, defaults to 2):
The number of top experts to route to for each token.
moe_num_shared_experts (`int`, *optional*, defaults to 2):
The number of shared experts.
"""
model_type = "aria_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `AriaTextModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_config_key = "text_config"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size: int = 4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=2,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
moe_num_experts: int = 8,
moe_topk: int = 2,
moe_num_shared_experts: int = 2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
self.moe_num_shared_experts = moe_num_shared_experts
class AriaConfig(PretrainedConfig):
r"""
This class handles the configuration for both vision and text components of the Aria model,
as well as additional parameters for image token handling and projector mapping.
Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
[rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`AriaVisionConfig` or `dict`, *optional*):
Configuration for the vision component.
vision_feature_layer (`int`, *optional*, defaults to -1):
The index of the layer to select the vision feature.
text_config (`AriaTextConfig` or `dict`, *optional*):
Configuration for the text component.
projector_patch_to_query_dict (`dict`, *optional*):
Mapping of patch sizes to query dimensions.
image_token_index (`int`, *optional*, defaults to 9):
Index used to represent image tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated normal initializer for initializing all weight matrices.
Attributes:
model_type (`str`):
Type of the model, set to `"aria"`.
image_token_index (`int`):
Index used to represent image tokens.
projector_patch_to_query_dict (`dict`):
Mapping of patch sizes to query dimensions.
vision_config (`AriaVisionConfig`):
Configuration for the vision component.
text_config (`AriaTextConfig`):
Configuration for the text component.
"""
model_type = "aria"
sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
def __init__(
self,
vision_config=None,
vision_feature_layer: int = -1,
text_config: AriaTextConfig = None,
projector_patch_to_query_dict: Dict = None,
image_token_index: int = 9,
initializer_range: float = 0.02,
**kwargs,
):
self.image_token_index = image_token_index
# Convert the keys and values of projector_patch_to_query_dict to integers
# This ensures consistency even if they were provided as strings
if projector_patch_to_query_dict is None:
projector_patch_to_query_dict = {
1225: 128,
4900: 256,
}
self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
self.vision_feature_layer = vision_feature_layer
if isinstance(vision_config, dict):
vision_config["model_type"] = "idefics3_vision"
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
vision_config = CONFIG_MAPPING["idefics3_vision"]()
self.vision_config = vision_config
self.initializer_range = initializer_range
if isinstance(text_config, dict) and "model_type" in text_config:
text_config = AriaTextConfig(**text_config)
elif text_config is None:
text_config = AriaTextConfig()
self.text_config = text_config
super().__init__(**kwargs)
__all__ = ["AriaConfig", "AriaTextConfig"]

View File

@ -0,0 +1,162 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import (
AddedToken,
AriaForConditionalGeneration,
AriaProcessor,
AutoConfig,
AutoTokenizer,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
Example for creating the old state dict file with Python:
import torch
from aria.model.language_model.aria_llama import AriaTextForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"vision_tower.vision_model": "vision_tower",
"ln_ffn": "layer_norm",
"ffn": "feed_forward",
"ln_kv": "layer_norm_kv",
}
def load_original_state_dict(model_id):
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
return original_state_dict
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
return new_state_dict
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
extra_special_tokens={
"image_token": "<|img|>",
"pad_token": "<pad>",
},
)
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
processor = AriaProcessor.from_pretrained(
text_model_id,
tokenizer=tokenizer,
)
config = AutoConfig.from_pretrained(text_model_id)
config.vision_config.hidden_size = 1152
config.vision_config.attention_heads = 16
config.pad_token_id = 2
config.image_token_index = 9
config.intermediate_size = config.moe_intermediate_size
config.auto_map = {
"AutoConfig": "modeling_aria.AriaConfig",
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
}
with torch.device("meta"):
model = AriaForConditionalGeneration(config)
state_dict = load_original_state_dict(old_state_dict_id)
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)
# print("Saving models")
# model.save_pretrained("local_aria", safe_serialization=False)
# processor.save_pretrained("local_aria")
print("Pushing to hub")
model.push_to_hub(output_hub_path, create_pr=True)
processor.push_to_hub(output_hub_path, create_pr=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
default="rhymes-ai/Aria",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
default="rhymes-ai/Aria",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
default="rhymes-ai/Aria",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
default="rhymes-ai/Aria",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,504 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_valid_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType
def make_batched_images(images) -> List[List[ImageInput]]:
"""
Accepts images in list or nested list format, and makes a list of images for preprocessing.
Args:
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of images.
"""
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
return [img for img_list in images for img in img_list]
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
return images
elif is_valid_image(images):
return [images]
raise ValueError(f"Could not make batched video from {images}")
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
"""
Divides an image into patches of a specified size.
Args:
image (`np.array`):
The input image.
patch_size (`int`):
The size of each patch.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
list: A list of np.array representing the patches.
"""
patches = []
height, width = get_image_size(image, channel_dim=input_data_format)
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
if input_data_format == ChannelDimension.LAST:
patch = image[i : i + patch_size, j : j + patch_size]
else:
patch = image[:, i : i + patch_size, j : j + patch_size]
patches.append(patch)
return patches
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
Initialize the AriaImageProcessor.
Args:
image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
Mean values for normalization.
image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
Standard deviation values for normalization.
max_image_size (`int`, *optional*, defaults to 980):
Maximum image size.
min_image_size (`int`, *optional*, defaults to 336):
Minimum image size.
split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples):
The optimal resolutions for splitting the image.
split_image (`bool`, *optional*, defaults to `False`):
Whether to split the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image.
resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
The resampling filter to use if resizing the image.
"""
def __init__(
self,
image_mean: List[float] = None,
image_std: List[float] = None,
max_image_size: int = 980,
min_image_size: int = 336,
split_resolutions: Optional[List[Tuple[int, int]]] = None,
split_image: Optional[bool] = False,
do_convert_rgb: Optional[bool] = True,
do_normalize: Optional[bool] = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
**kwargs,
):
super().__init__(**kwargs)
if image_mean is None:
image_mean = [0.5, 0.5, 0.5]
if image_std is None:
image_std = [0.5, 0.5, 0.5]
self.max_image_size = max_image_size
self.min_image_size = min_image_size
self.image_mean = image_mean
self.image_std = image_std
self.split_image = split_image
if split_resolutions is None:
split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
self.split_resolutions = split_resolutions
self.do_convert_rgb = do_convert_rgb
self.do_normalize = do_normalize
self.resample = resample
def preprocess(
self,
images: Union[ImageInput, List[ImageInput]],
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
max_image_size: Optional[int] = None,
min_image_size: Optional[int] = None,
split_image: Optional[bool] = None,
do_convert_rgb: Optional[bool] = None,
do_normalize: Optional[bool] = None,
resample: PILImageResampling = None,
return_tensors: Optional[Union[str, TensorType]] = "pt",
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Process a list of images.
Args:
images (ImageInput or list of ImageInput):
The input image or a list of images.
image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
Mean values for normalization.
image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
Standard deviation values for normalization.
max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
Maximum image size.
min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
Minimum image size.
split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
Whether to split the image.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
Whether to convert the image to RGB.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
Whether to normalize the image.
resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
The resampling filter to use if resizing the image.
return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
The type of tensor to return.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`:
image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`:
image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`:
image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`:
image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
BatchFeature:
A BatchFeature object containing:
- 'pixel_values':
Tensor of processed image pixel values.
- 'pixel_mask':
Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
- True (1) values indicate pixels that belong to the original resized image.
- False (0) values indicate pixels that are part of the padding.
The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
- 'num_crops':
The maximum number of crops across all images.
"""
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
max_image_size = max_image_size if max_image_size is not None else self.max_image_size
min_image_size = min_image_size if min_image_size is not None else self.min_image_size
split_image = split_image if split_image is not None else self.split_image
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
if max_image_size not in [490, 980]:
raise ValueError("max_image_size must be either 490 or 980")
images = make_batched_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
resample=resample,
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
pixel_values = []
pixel_masks = []
num_crops = None
for image in images:
if split_image:
crop_images = self.get_image_patches(
image,
self.split_resolutions,
max_image_size,
resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
else:
crop_images = [image]
if num_crops is None or len(crop_images) > num_crops:
num_crops = len(crop_images)
for crop_image in crop_images:
# At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
h, w = get_image_size(crop_image)
scale = max_image_size / max(h, w)
if w >= h:
new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
else:
new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
crop_image_resized = resize(
crop_image,
new_size,
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
crop_image_padded = pad(
crop_image_resized,
((0, padding_bottom), (0, padding_right)),
data_format=input_data_format,
input_data_format=input_data_format,
)
# Create a pixel mask
pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
pixel_mask[: new_size[0], : new_size[1]] = 1
pixel_masks.append(pixel_mask)
if do_normalize:
crop_image_padded = self.normalize(
crop_image_padded / 255.0,
self.image_mean,
self.image_std,
data_format=input_data_format,
input_data_format=input_data_format,
)
crop_image_padded = (
to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
if data_format is not None
else crop_image_padded
)
pixel_values.append(crop_image_padded)
return BatchFeature(
data={
"pixel_values": np.stack(pixel_values, axis=0),
"pixel_mask": np.stack(pixel_masks, axis=0),
"num_crops": num_crops,
},
tensor_type=return_tensors,
)
def _resize_for_patching(
self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension
) -> np.array:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image (np.array):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
resample (`PILImageResampling`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
return resized_image
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
return padded_image
def pad(
self,
image: np.ndarray,
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
mode: PaddingMode = PaddingMode.CONSTANT,
constant_values: Union[float, Iterable[float]] = 0.0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
as input.
Args:
image (`np.ndarray`):
The image to pad.
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
Padding to apply to the edges of the height, width axes. Can be one of three formats:
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
- `((before, after),)` yields same before and after pad for height and width.
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
mode (`PaddingMode`):
The padding mode to use. Can be one of:
- `"constant"`: pads with a constant value.
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
vector along each axis.
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
# call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
if isinstance(padding, int) or len(padding) != 4:
return pad(image, padding, mode, constant_values, data_format, input_data_format)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
padding_mode_mapping = {
PaddingMode.CONSTANT: "constant",
PaddingMode.REFLECT: "reflect",
PaddingMode.REPLICATE: "edge",
PaddingMode.SYMMETRIC: "symmetric",
}
image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
image = (
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
)
return image
def get_image_patches(
self,
image: np.array,
grid_pinpoints: List[Tuple[int, int]],
patch_size: int,
resample: PILImageResampling,
data_format: ChannelDimension,
input_data_format: ChannelDimension,
) -> List[np.array]:
"""
Process an image with variable resolutions by dividing it into patches.
Args:
image (`np.array`):
The input image to be processed.
grid_pinpoints (List[Tuple[int, int]]):
A list of possible resolutions as tuples.
patch_size (`int`):
Size of the patches to divide the image into.
resample (`PILImageResampling`):
Resampling filter to use if resizing the image.
data_format (`ChannelDimension` or `str`):
The channel dimension format for the output image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
`List[np.array]`: A list of NumPy arrays containing the processed image patches.
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints must be a list of possible resolutions.")
possible_resolutions = grid_pinpoints
image_size = get_image_size(image, channel_dim=input_data_format)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_patching(
image, best_resolution, resample=resample, input_data_format=input_data_format
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
# make sure that all patches are in the input data format
patches = [
to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
for patch in patches
]
return patches
__all__ = ["AriaImageProcessor"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,164 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aria/modular_aria.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import TensorType
from ..auto import AutoTokenizer
class AriaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
"max_image_size": 980,
"split_image": False,
},
"return_tensors": TensorType.PYTORCH,
}
class AriaProcessor(ProcessorMixin):
"""
AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
Args:
image_processor (`AriaImageProcessor`, *optional*):
The AriaImageProcessor to use for image preprocessing.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
chat_template (`str`, *optional*):
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
size_conversion (`Dict`, *optional*):
A dictionary indicating size conversions for images.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "size_conversion"]
image_processor_class = "AriaImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer: Union[AutoTokenizer, str] = None,
chat_template: Optional[str] = None,
size_conversion: Optional[Dict[Union[float, int], int]] = None,
):
if size_conversion is None:
size_conversion = {490: 128, 980: 256}
self.size_conversion = {int(k): v for k, v in size_conversion.items()}
if tokenizer is not None and tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
images: Optional[ImageInput] = None,
audio=None,
videos=None,
**kwargs: Unpack[AriaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s).
Args:
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`ImageInput`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
AriaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is not None:
image_inputs = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
# expand the image_token according to the num_crops and tokens per image
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
prompt_strings = []
num_crops = image_inputs.pop("num_crops") * tokens_per_image
for sample in text:
sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
prompt_strings.append(sample)
else:
image_inputs = {}
prompt_strings = text
text_inputs = self.tokenizer(
prompt_strings,
**output_kwargs["text_kwargs"],
)
return BatchFeature(data={**text_inputs, **image_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["AriaProcessor"]

View File

@ -35,6 +35,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("albert", "AlbertConfig"),
("align", "AlignConfig"),
("altclip", "AltCLIPConfig"),
("aria", "AriaConfig"),
("aria_text", "AriaTextConfig"),
("audio-spectrogram-transformer", "ASTConfig"),
("autoformer", "AutoformerConfig"),
("bark", "BarkConfig"),
@ -135,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("idefics3", "Idefics3Config"),
("idefics3_vision", "Idefics3VisionConfig"),
("ijepa", "IJepaConfig"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
@ -327,6 +330,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("albert", "ALBERT"),
("align", "ALIGN"),
("altclip", "AltCLIP"),
("aria", "Aria"),
("aria_text", "AriaText"),
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
("autoformer", "Autoformer"),
("bark", "Bark"),
@ -441,6 +446,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("idefics3", "Idefics3"),
("idefics3_vision", "Idefics3VisionTransformer"),
("ijepa", "I-JEPA"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
@ -687,6 +693,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("clip_vision_model", "clip"),
("qwen2_audio_encoder", "qwen2_audio"),
("clip_text_model", "clip"),
("aria_text", "aria"),
("idefics3_vision", "idefics3"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),

View File

@ -54,6 +54,7 @@ else:
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("align", ("EfficientNetImageProcessor",)),
("aria", ("AriaImageProcessor")),
("beit", ("BeitImageProcessor",)),
("bit", ("BitImageProcessor",)),
("blip", ("BlipImageProcessor",)),

View File

@ -35,6 +35,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
("aria", "AriaForConditionalGeneration"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
("bark", "BarkModel"),
@ -132,6 +134,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("idefics3", "Idefics3Model"),
("idefics3_vision", "Idefics3VisionTransformer"),
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
@ -464,6 +467,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("aria_text", "AriaTextForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
@ -768,6 +772,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
[
("aria", "AriaForConditionalGeneration"),
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("chameleon", "ChameleonForConditionalGeneration"),

View File

@ -47,6 +47,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("align", "AlignProcessor"),
("altclip", "AltCLIPProcessor"),
("aria", "AriaProcessor"),
("bark", "BarkProcessor"),
("blip", "BlipProcessor"),
("blip-2", "Blip2Processor"),

View File

@ -68,6 +68,7 @@ else:
),
),
("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
(

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_idefics3": ["Idefics3Config"]}
_import_structure = {"configuration_idefics3": ["Idefics3Config", "Idefics3VisionConfig"]}
try:
@ -38,11 +38,12 @@ else:
"Idefics3ForConditionalGeneration",
"Idefics3PreTrainedModel",
"Idefics3Model",
"Idefics3VisionTransformer",
]
_import_structure["processing_idefics3"] = ["Idefics3Processor"]
if TYPE_CHECKING:
from .configuration_idefics3 import Idefics3Config
from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
try:
if not is_vision_available():
@ -62,6 +63,7 @@ if TYPE_CHECKING:
Idefics3ForConditionalGeneration,
Idefics3Model,
Idefics3PreTrainedModel,
Idefics3VisionTransformer,
)
from .processing_idefics3 import Idefics3Processor

View File

@ -685,6 +685,41 @@ class AltCLIPVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AriaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AriaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AriaTextForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AriaTextModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AriaTextPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ASTForAudioClassification(metaclass=DummyObject):
_backends = ["torch"]
@ -4978,6 +5013,20 @@ class Idefics3Processor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Idefics3VisionConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics3VisionTransformer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class IJepaForImageClassification(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -23,6 +23,13 @@ class ImageFeatureExtractionMixin(metaclass=DummyObject):
requires_backends(self, ["vision"])
class AriaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class BeitFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]

View File

@ -1727,6 +1727,7 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers
inputs_embeds = model.get_input_embeddings()(input_ids)
max_cache_len += inputs_embeds.shape[1]
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
# we should get `max_length` in shape, not `max_length - embeds_length`

View File

View File

@ -0,0 +1,268 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
if is_vision_available():
from PIL import Image
from transformers import AriaImageProcessor
if is_torch_available():
import torch
class AriaImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
num_images=1,
min_resolution=30,
max_resolution=40,
size=None,
max_image_size=980,
min_image_size=336,
split_resolutions=None,
split_image=True,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_convert_rgb=True,
resample=PILImageResampling.BICUBIC,
):
super().__init__()
self.size = size if size is not None else {"longest_edge": max_resolution}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.num_images = num_images
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.resample = resample
self.max_image_size = max_image_size
self.min_image_size = min_image_size
self.split_resolutions = split_resolutions if split_resolutions is not None else [[980, 980]]
self.split_image = split_image
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def prepare_image_processor_dict(self):
return {
"image_mean": self.image_mean,
"image_std": self.image_std,
"max_image_size": self.max_image_size,
"min_image_size": self.min_image_size,
"split_resolutions": self.split_resolutions,
"split_image": self.split_image,
"do_convert_rgb": self.do_convert_rgb,
"do_normalize": self.do_normalize,
"resample": self.resample,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to AriaImageProcessor,
assuming do_resize is set to True. The expected size in that case the max image size.
"""
return self.max_image_size, self.max_image_size
def expected_output_image_shape(self, images):
height, width = self.get_expected_values(images, batched=True)
return self.num_channels, height, width
def prepare_image_inputs(
self,
batch_size=None,
min_resolution=None,
max_resolution=None,
num_channels=None,
num_images=None,
size_divisor=None,
equal_resolution=False,
numpify=False,
torchify=False,
):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
One can specify whether the images are of the same resolution or not.
"""
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
batch_size = batch_size if batch_size is not None else self.batch_size
min_resolution = min_resolution if min_resolution is not None else self.min_resolution
max_resolution = max_resolution if max_resolution is not None else self.max_resolution
num_channels = num_channels if num_channels is not None else self.num_channels
num_images = num_images if num_images is not None else self.num_images
images_list = []
for i in range(batch_size):
images = []
for j in range(num_images):
if equal_resolution:
width = height = max_resolution
else:
# To avoid getting image width/height 0
if size_divisor is not None:
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
min_resolution = max(size_divisor, min_resolution)
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
images_list.append(images)
if not numpify and not torchify:
# PIL expects the channel dimension as last dimension
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
if torchify:
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
if numpify:
# Numpy images are typically in channels last format
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
return images_list
@require_torch
@require_vision
class AriaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = AriaImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = AriaImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "min_image_size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "split_image"))
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_numpy_4_channels(self):
# Aria always processes images as RGB, so it always returns images with 3 channels
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor_dict = self.image_processor_dict
image_processing = self.image_processing_class(**image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)

View File

@ -0,0 +1,669 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Aria model."""
import gc
import unittest
import requests
from transformers import (
AriaConfig,
AriaForConditionalGeneration,
AriaTextConfig,
AutoProcessor,
AutoTokenizer,
is_torch_available,
is_vision_available,
)
from transformers.models.idefics3 import Idefics3VisionConfig
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
class AriaVisionText2TextModelTester:
def __init__(
self,
parent,
ignore_index=-100,
image_token_index=9,
projector_hidden_act="gelu",
seq_length=7,
vision_feature_select_strategy="default",
vision_feature_layer=-1,
text_config=AriaTextConfig(
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=1,
hidden_size=32,
intermediate_size=64,
max_position_embeddings=60,
model_type="aria_moe_lm",
moe_intermediate_size=4,
moe_num_experts=4,
moe_topk=2,
num_attention_heads=20,
num_experts_per_tok=3,
num_hidden_layers=2,
num_key_value_heads=20,
rope_theta=5000000,
vocab_size=99,
eos_token_id=2,
head_dim=2,
),
is_training=True,
vision_config=Idefics3VisionConfig(
image_size=358,
patch_size=10,
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=20,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=10,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
),
):
self.parent = parent
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.text_config = text_config
self.vision_config = vision_config
self.pad_token_id = text_config.pad_token_id
self.eos_token_id = text_config.eos_token_id
self.num_hidden_layers = text_config.num_hidden_layers
self.vocab_size = text_config.vocab_size
self.hidden_size = text_config.hidden_size
self.num_attention_heads = text_config.num_attention_heads
self.is_training = is_training
self.batch_size = 10
self.num_channels = 3
self.image_size = 358
self.num_image_tokens = 128
self.seq_length = seq_length + self.num_image_tokens
def get_config(self):
return AriaConfig(
text_config=self.text_config,
vision_config=self.vision_config,
ignore_index=self.ignore_index,
image_token_index=self.image_token_index,
projector_hidden_act=self.projector_hidden_act,
vision_feature_select_strategy=self.vision_feature_select_strategy,
vision_feature_layer=self.vision_feature_layer,
eos_token_id=self.eos_token_id,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config.num_channels,
self.vision_config.image_size,
self.vision_config.image_size,
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = config.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
def create_and_check_aria_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
model = AriaForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `AriaForConditionalGeneration`.
"""
all_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = AriaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="")
def test_new_cache_format_0(self):
pass
@unittest.skip(reason="")
def test_new_cache_format_1(self):
pass
@unittest.skip(reason="")
def test_new_cache_format_2(self):
pass
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="Unstable test")
def test_initialization(self):
pass
@unittest.skip(reason="Unstable test")
def test_dola_decoding_sample(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_0_greedy(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_with_static_cache(self):
pass
@require_torch
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://aria-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_single(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image_file = "https://aria-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
"USER: <image>\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batch(self):
# Let' s make sure we test the preprocessing to replace what is used
model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
# The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!.
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = [
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.',
'USER: \nWhat is this?\nASSISTANT: Cats'
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched_regression(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "rhymes-ai/Aria"
# Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True, attn_implementation="eager")
processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_torch
@require_vision
def test_batched_generation(self):
model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")
prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
image1 = Image.open(requests.get(url1, stream=True).raw)
image2 = Image.open(requests.get(url2, stream=True).raw)
inputs = processor(
images=[image1, image2, image1, image2],
text=[prompt1, prompt2, prompt3],
return_tensors="pt",
padding=True,
).to(torch_device)
model = model.eval()
EXPECTED_OUTPUT = [
"\n \nUSER: What's the the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while",
"\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. The llama is the",
]
generate_ids = model.generate(**inputs, max_new_tokens=20)
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(outputs, EXPECTED_OUTPUT)
@slow
@require_bitsandbytes
def test_aria_index_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
# more details
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
# Simulate a super long prompt
user_prompt = "Describe the image:?\n" * 200
prompt = f"USER: <image>\n{user_prompt}ASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_torch_gpu
def test_aria_merge_inputs_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
# Simulate some user inputs
pixel_values = torch.randn(
(1, 3, 336, 336),
dtype=torch.float,
device=torch_device,
)
input_ids = torch.tensor(
[
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
],
dtype=torch.long,
device=torch_device,
)
attention_mask = torch.tensor(
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.long,
device=torch_device,
)
# Make sure that the loss is properly computed
loss = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
).loss
loss.backward()
def test_tokenizer_integration(self):
model_id = "rhymes-ai/Aria"
slow_tokenizer = AutoTokenizer.from_pretrained(
model_id, bos_token="<|startoftext|>", eos_token="<|endoftext|>", use_fast=False
)
slow_tokenizer.add_tokens("<image>", True)
fast_tokenizer = AutoTokenizer.from_pretrained(
model_id,
bos_token="<|startoftext|>",
eos_token="<|endoftext|>",
from_slow=True,
legacy=False,
)
fast_tokenizer.add_tokens("<image>", True)
prompt = "<|startoftext|><|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>"
EXPECTED_OUTPUT = ['<|startoftext|>', '<', '|', 'im', '_', 'start', '|', '>', 'system', '\n', 'Answer', '▁the', '▁questions', '.<', '|', 'im', '_', 'end', '|', '><', '|', 'im', '_', 'start', '|', '>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<', '|', 'im', '_', 'end', '|', '>'] # fmt: skip
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
@slow
@require_bitsandbytes
def test_generation_no_images(self):
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
# Prepare inputs with no images
inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_bitsandbytes
def test_generation_siglip_backbone(self):
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16", device_map=torch_device)
processor = AutoProcessor.from_pretrained(model_id)
# check processing with expansion of inputs (w/o expansion should work with any backbone)
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(
text="<|im_start|>user\n<image>\nWhat are these?<|im_end|>\n<|im_start|>assistant",
images=raw_image,
return_tensors="pt",
).to(torch_device, torch.float16)
# Make sure that `generate` works
output = model.generate(**inputs, max_new_tokens=30)
EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat"
self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_expansion_in_processing(self):
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nDescribe the image:\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_pixtral(self):
model_id = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
IMG_URLS = [
Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw),
Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw),
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
# image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# fmt: off
EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:
1. **Image 1:**
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:**
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:**
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:**
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""
# fmt: on
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(ouptut, EXPECTED_GENERATION)

View File

@ -0,0 +1,391 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from io import BytesIO
from typing import Optional
import numpy as np
import requests
from transformers import AriaProcessor
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from PIL import Image
@require_torch
@require_vision
class AriaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = AriaProcessor
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
processor = AriaProcessor.from_pretrained("m-ric/Aria_hf_2", image_seq_len=2)
processor.save_pretrained(cls.tmpdirname)
cls.image1 = Image.open(
BytesIO(
requests.get(
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content
)
)
cls.image2 = Image.open(
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
)
cls.image3 = Image.open(
BytesIO(
requests.get(
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
).content
)
)
cls.bos_token = "<|im_start|>"
cls.eos_token = "<|im_end|>"
cls.image_token = processor.tokenizer.image_token
cls.fake_image_token = "o"
cls.global_img_token = "<|img|>"
cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token)
cls.eos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.eos_token)
cls.image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.image_token)
cls.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.fake_image_token)
cls.global_img_tokens_id = processor.tokenizer(cls.global_img_token, add_special_tokens=False)["input_ids"]
cls.padding_token_id = processor.tokenizer.pad_token_id
cls.image_seq_len = 256
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def get_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname)
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["image_processor"] = self.get_component(
"image_processor", do_rescale=True, rescale_factor=1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_process_interleaved_images_prompts_image_splitting(self):
processor = self.get_processor()
processor.image_processor.split_image = True
# Test that a single image is processed correctly
inputs = processor(images=self.image1, text="Ok<|img|>", images_kwargs={"split_image": True})
self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 3, 980, 980))
self.assertEqual(np.array(inputs["pixel_mask"]).shape, (2, 980, 980))
def test_process_interleaved_images_prompts_no_image_splitting(self):
processor = self.get_processor()
processor.image_processor.split_image = False
# Test that a single image is processed correctly
inputs = processor(images=self.image1, text="Ok<|img|>")
image1_expected_size = (980, 980)
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 3, *image1_expected_size))
self.assertEqual(np.array(inputs["pixel_mask"]).shape, (1, *image1_expected_size))
# fmt: on
# Test a single sample with image and text
image_str = "<|img|>"
text_str = "In this image, we see"
text = image_str + text_str
inputs = processor(text=text, images=self.image1)
# fmt: off
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [[self.image_token_id] * self.image_seq_len + tokenized_sentence["input_ids"]]
# self.assertEqual(len(inputs["input_ids"]), len(expected_input_ids))
self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 3, *image1_expected_size))
self.assertEqual(np.array(inputs["pixel_mask"]).shape, (1, *image1_expected_size))
# fmt: on
# Test that batch is correctly processed
image_str = "<|img|>"
text_str_1 = "In this image, we see"
text_str_2 = "In this image, we see"
text = [
image_str + text_str_1,
image_str + image_str + text_str_2,
]
images = [[self.image1], [self.image2, self.image3]]
inputs = processor(text=text, images=images, padding=True)
# fmt: off
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
image_tokens = [self.image_token_id] * self.image_seq_len
expected_input_ids_1 = image_tokens + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = 2 * image_tokens + tokenized_sentence_2["input_ids"]
# Pad the first input to match the second input
pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
expected_attention_mask = [[0] * pad_len + [1] * len(expected_input_ids_1), [1] * (len(expected_input_ids_2))]
self.assertEqual(
inputs["attention_mask"],
expected_attention_mask
)
self.assertEqual(np.array(inputs['pixel_values']).shape, (3, 3, 980, 980))
self.assertEqual(np.array(inputs['pixel_mask']).shape, (3, 980, 980))
# fmt: on
def test_non_nested_images_with_batched_text(self):
processor = self.get_processor()
processor.image_processor.do_image_splitting = False
image_str = "<|img|>"
text_str_1 = "In this image, we see"
text_str_2 = "In this image, we see"
text = [
image_str + text_str_1,
image_str + image_str + text_str_2,
]
images = [self.image1, self.image2, self.image3]
inputs = processor(text=text, images=images, padding=True)
self.assertEqual(np.array(inputs["pixel_values"]).shape, (3, 3, 980, 980))
self.assertEqual(np.array(inputs["pixel_mask"]).shape, (3, 980, 980))
def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What do these images show?"},
{"type": "image"},
{"type": "image"},
"What do these images show?",
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.",
}
],
},
{"role": "user", "content": [{"type": "text", "text": "And who is that?"}]},
]
processor = self.get_processor()
# Make short sequence length to test that the fake tokens are added correctly
rendered = processor.apply_chat_template(messages, add_generation_prompt=True)
print(rendered)
expected_rendered = """<|im_start|>user
What do these images show?<fim_prefix><|img|><fim_suffix><fim_prefix><|img|><fim_suffix><|im_end|>
<|im_start|>assistant
The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.<|im_end|>
<|im_start|>user
And who is that?<|im_end|>
<|im_start|>assistant
"""
self.assertEqual(rendered, expected_rendered)
# Override as AriaProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <|img|>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <|img|>"]
return ["lower newer <|img|>", "<|img|> upper older longer string"] + ["<|img|> lower newer"] * (
batch_size - 2
)
# Override tests as inputs_ids padded dimension is the second one but not the last one
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=30)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=30)
self.assertEqual(len(inputs["input_ids"][0]), 30)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
inputs = processor(
text=input_str,
images=image_input,
common_kwargs={"return_tensors": "pt"},
images_kwargs={"max_image_size": 980},
text_kwargs={"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[3], 980)
self.assertEqual(len(inputs["input_ids"][0]), 120)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_image_size": 980},
"text_kwargs": {"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[3], 980)
self.assertEqual(len(inputs["input_ids"][0]), 120)
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=30)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 30)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs(batch_size=2)
image_input = self.prepare_image_inputs(batch_size=2)
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="longest",
max_length=76,
truncation=True,
max_image_size=980,
)
self.assertEqual(inputs["pixel_values"].shape[1], 3)
self.assertEqual(inputs["pixel_values"].shape[3], 980)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_image_size=980,
padding="max_length",
max_length=120,
truncation="longest_first",
)
self.assertEqual(inputs["pixel_values"].shape[3], 980)
self.assertEqual(len(inputs["input_ids"][0]), 120)

View File

@ -865,6 +865,7 @@ def match_docstring_with_signature(obj: Any) -> Optional[Tuple[str, str]]:
# We went too far by one (perhaps more if there are a lot of new lines)
idx -= 1
if current_arg:
while len(obj_doc_lines[idx].strip()) == 0:
arguments[current_arg] = arguments[current_arg][:-1]
idx -= 1

View File

@ -85,6 +85,8 @@ PRIVATE_MODELS = [
"Idefics2PerceiverResampler",
"Idefics2VisionTransformer",
"Idefics3VisionTransformer",
"AriaTextForCausalLM",
"AriaTextModel",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.

View File

@ -1678,7 +1678,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma/modular_gemma.py"],
default=["src/transformers/models/aria/modular_aria.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)