mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Updated Aria model card (#38472)
Some checks failed
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
New model PR merged notification / Notify new model (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
New model PR merged notification / Notify new model (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
* Update aria.md * Update aria.md * Suggested Updates - aria.md
This commit is contained in:
parent
77cf4936fe
commit
593e29c5e2
@ -14,60 +14,71 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Aria
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Aria
|
||||
|
||||
The Aria model was proposed in [Aria: An Open Multimodal Native Mixture-of-Experts Model](https://huggingface.co/papers/2410.05993) by Li et al. from the Rhymes.AI team.
|
||||
[Aria](https://huggingface.co/papers/2410.05993) is a multimodal mixture-of-experts (MoE) model. The goal of this model is to open-source a training recipe for creating a multimodal native model from scratch. Aria has 3.9B and 3.5B activated parameters per visual and text token respectively. Text is handled by a MoE decoder and visual inputs are handled by a lightweight visual encoder. It is trained in 4 stages, language pretraining, multimodal pretraining, multimodal long-context pretraining, and multimodal post-training.
|
||||
|
||||
Aria is an open multimodal-native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. It has a Mixture-of-Experts architecture, with respectively 3.9B and 3.5B activated parameters per visual token and text token.
|
||||
You can find all the original Aria checkpoints under the [Aria](https://huggingface.co/rhymes-ai?search_models=aria) organization.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
> [!TIP]
|
||||
> Click on the Aria models in the right sidebar for more examples of how to apply Aria to different multimodal tasks.
|
||||
|
||||
*Information comes in diverse modalities. Multimodal native AI models are essential to integrate real-world information and deliver comprehensive understanding. While proprietary multimodal native models exist, their lack of openness imposes obstacles for adoptions, let alone adaptations. To fill this gap, we introduce Aria, an open multimodal native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. Aria is a mixture-of-expert model with 3.9B and 3.5B activated parameters per visual token and text token, respectively. It outperforms Pixtral-12B and Llama3.2-11B, and is competitive against the best proprietary models on various multimodal tasks. We pre-train Aria from scratch following a 4-stage pipeline, which progressively equips the model with strong capabilities in language understanding, multimodal understanding, long context window, and instruction following. We open-source the model weights along with a codebase that facilitates easy adoptions and adaptations of Aria in real-world applications.*
|
||||
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
This model was contributed by [m-ric](https://huggingface.co/m-ric).
|
||||
The original code can be found [here](https://github.com/rhymes-ai/Aria).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage tips
|
||||
|
||||
Here's how to use the model for vision tasks:
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
|
||||
from transformers import AriaProcessor, AriaForConditionalGeneration
|
||||
pipeline = pipeline(
|
||||
"image-to-text",
|
||||
model="rhymes-ai/Aria",
|
||||
device=0,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
text="What is shown in this image?"
|
||||
)
|
||||
```
|
||||
|
||||
model_id_or_path = "rhymes-ai/Aria"
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
model = AriaForConditionalGeneration.from_pretrained(
|
||||
model_id_or_path, device_map="auto"
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"rhymes-ai/Aria",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
processor = AriaProcessor.from_pretrained(model_id_or_path)
|
||||
|
||||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"text": "what is the image?", "type": "text"},
|
||||
],
|
||||
}
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
text = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
inputs = processor(text=text, images=image, return_tensors="pt")
|
||||
inputs.to(model.device)
|
||||
inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")
|
||||
ipnuts = inputs.to(model.device, torch.bfloat16)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
@ -79,6 +90,55 @@ output = model.generate(
|
||||
)
|
||||
output_ids = output[0][inputs["input_ids"].shape[1]:]
|
||||
response = processor.decode(output_ids, skip_special_tokens=True)
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4 and the [rhymes-ai/Aria-sequential_mlp](https://huggingface.co/rhymes-ai/Aria-sequential_mlp) checkpoint. This checkpoint replaces grouped GEMM with `torch.nn.Linear` layers for easier quantization.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoProcessor
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"rhymes-ai/Aria-sequential_mlp",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"rhymes-ai/Aria-sequential_mlp",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")
|
||||
inputs = inputs.to(model.device, torch.bfloat16)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=15,
|
||||
stop_strings=["<|im_end|>"],
|
||||
tokenizer=processor.tokenizer,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
)
|
||||
output_ids = output[0][inputs["input_ids"].shape[1]:]
|
||||
response = processor.decode(output_ids, skip_special_tokens=True)
|
||||
print(response)
|
||||
```
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user