# Chameleon
## Overview
The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models](https://arxiv.org/abs/2405.09818v1) by the Meta AI Chameleon Team. Chameleon is a vision-language model that uses vector quantization to tokenize images, which enables the model to generate multimodal output. The model takes images and text as input, including in an interleaved format, and generates a textual response. The image generation module has not been released yet.
The abstract from the paper is the following:
*We present Chameleon, a family of early-fusion token-based mixed-modal models capable of understanding and generating images and text in any arbitrary sequence. We outline a stable training
approach from inception, an alignment recipe, and an architectural parameterization tailored for the
early-fusion, token-based, mixed-modal setting. The models are evaluated on a comprehensive range
of tasks, including visual question answering, image captioning, text generation, image generation, and
long-form mixed modal generation. Chameleon demonstrates broad and general capabilities, including
state-of-the-art performance in image captioning tasks, outperforms Llama-2 in text-only tasks while
being competitive with models such as Mixtral 8x7B and Gemini-Pro, and performs non-trivial image
generation, all in a single model. It also matches or exceeds the performance of much larger models,
including Gemini Pro and GPT-4V, according to human judgments on a new long-form mixed-modal
generation evaluation, where either the prompt or outputs contain mixed sequences of both images and
text. Chameleon marks a significant step forward in a unified modeling of full multimodal documents.*
Chameleon incorporates a vector quantizer module to transform images into discrete tokens. This also enables image generation using an auto-regressive transformer. Description taken from the [original paper](https://arxiv.org/abs/2405.09818v1).
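
To make the idea of turning continuous features into discrete tokens concrete, here is a minimal sketch of nearest-neighbour vector quantization. It is a toy illustration only: the 2-D codebook, the patch vectors, and the `quantize` helper are hypothetical and are not taken from Chameleon's actual VQ-VAE, which learns a much larger codebook over image patches.

```python
def quantize(vectors, codebook):
    """Map each vector to the index of its nearest codebook entry (squared L2 distance)."""
    tokens = []
    for v in vectors:
        distances = [
            sum((a - b) ** 2 for a, b in zip(v, code))
            for code in codebook
        ]
        # The index of the closest code becomes the discrete token id
        tokens.append(distances.index(min(distances)))
    return tokens


# Hypothetical 2-D codebook with 4 entries
codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
# Hypothetical continuous "patch embeddings"
patches = [(0.1, 0.2), (0.9, 0.1), (0.8, 0.95)]

token_ids = quantize(patches, codebook)
print(token_ids)  # [0, 1, 3]
```

Once images are expressed as token ids like these, they can share a sequence with text tokens and be modeled by the same auto-regressive transformer.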
This model was contributed by [joaogante](https://huggingface.co/joaogante) and [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/facebookresearch/chameleon).
## Usage tips
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating.
- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question.
- Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor.
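
The left-padding advice above can be illustrated with a small sketch. The `pad_batch` helper and the token ids are hypothetical and only mimic what a tokenizer does when `padding_side` is set; they are not the tokenizer's implementation. The point is that with left padding every row ends in a real token, so generation continues directly from the prompt rather than from pad tokens.

```python
def pad_batch(sequences, pad_id, side="left"):
    """Pad token id lists to a common length and build matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pads, ones = [pad_id] * n_pad, [1] * len(seq)
        if side == "left":
            input_ids.append(pads + list(seq))
            attention_mask.append([0] * n_pad + ones)
        else:
            input_ids.append(list(seq) + pads)
            attention_mask.append(ones + [0] * n_pad)
    return input_ids, attention_mask


# Two hypothetical tokenized prompts of different lengths
batch = [[5, 6, 7], [8]]

left_ids, left_mask = pad_batch(batch, pad_id=0, side="left")
right_ids, right_mask = pad_batch(batch, pad_id=0, side="right")

print(left_ids)   # [[5, 6, 7], [0, 0, 8]]  -> every row ends with a real token
print(right_ids)  # [[5, 6, 7], [8, 0, 0]]  -> the short row ends with padding
```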
> [!NOTE]
> The Chameleon implementation in Transformers uses a special image token to indicate where image embeddings are merged into the sequence. Rather than adding a new token, it reuses one of the tokenizer's reserved tokens. Add `<image>` to your prompt at each position where an image should be embedded.
## Usage example
### Single image inference
Here's how to load the model and perform inference in half-precision (`torch.float16`):
```python
from transformers import ChameleonProcessor, ChameleonForCausalLM
import torch
from PIL import Image
import requests
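# "facebook/chameleon-7b" is the 7B checkpoint on the Hub; swap in another Chameleon checkpoint if needed
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.float16, device_map="auto")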
# prepare image and text prompt
url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg"
image = Image.open(requests.get(url, stream=True).raw)
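# "<image>" marks where the image embeddings are inserted
prompt = "What color is the belt in this image?<image>"
# cast floating-point inputs (pixel values) to the model's dtype
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, dtype=torch.float16)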
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```
### Multi image inference
Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it:
```python
from transformers import ChameleonProcessor, ChameleonForCausalLM
import torch
from PIL import Image
import requests
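# "facebook/chameleon-7b" is the 7B checkpoint on the Hub; swap in another Chameleon checkpoint if needed
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.float16, device_map="auto")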
# Get three different images
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image_stop = Image.open(requests.get(url, stream=True).raw)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_cats = Image.open(requests.get(url, stream=True).raw)
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)
# Prepare a batched prompt, where the first one is a multi-image prompt and the second is not
prompts = [
    "What do these images have in common?<image><image>",
    "What is shown in this image?<image>"
]
# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device, dtype=torch.float16)
# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```
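
The way a flat list of images is consumed by the prompts in order can be sketched in plain Python. This is an illustration under assumptions, not the processor's code: `assign_images` is a hypothetical helper, the placeholder string is assumed to be `<image>`, and the images are replaced by string stand-ins.

```python
def assign_images(prompts, images, image_token="<image>"):
    """Group a flat image list by prompt, one image per placeholder, in order."""
    counts = [prompt.count(image_token) for prompt in prompts]
    if sum(counts) != len(images):
        raise ValueError(
            f"Found {sum(counts)} placeholders but {len(images)} images"
        )
    grouped, start = [], 0
    for n in counts:
        grouped.append(images[start:start + n])
        start += n
    return grouped


prompts = [
    "What do these images have in common?<image><image>",
    "What is shown in this image?<image>",
]
# String stand-ins for the three PIL images
images = ["image_stop", "image_cats", "image_snowman"]

print(assign_images(prompts, images))
# [['image_stop', 'image_cats'], ['image_snowman']]
```

A check like this before calling the processor can catch a mismatch between the number of placeholders and the number of images early.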
## Model optimization
### Quantization using Bitsandbytes
The model can be loaded in 8 or 4 bits, greatly reducing memory requirements while largely preserving the quality of the original model. First install bitsandbytes with `pip install bitsandbytes` and make sure you have access to a CUDA-compatible GPU. Then change the model-loading call in the snippets above to:
```python
import torch
from transformers import ChameleonForCausalLM, BitsAndBytesConfig

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", quantization_config=quantization_config, device_map="auto")
```
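
As a rough sanity check on the memory savings, the weight footprint can be estimated from the parameter count and the bits stored per parameter. This is back-of-the-envelope arithmetic only: the 7B parameter count is an assumption about which checkpoint is loaded, `weight_memory_gb` is a hypothetical helper, and the estimate ignores quantization constants, layers kept in higher precision, activations, and the KV cache.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # assumption: a 7B-parameter checkpoint

for label, bits in [("float16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# float16: ~14.0 GB
# 8-bit: ~7.0 GB
# 4-bit: ~3.5 GB
```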
### Use Flash-Attention 2 and SDPA to further speed up generation
The model supports both Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA), either of which can be enabled as an optimization. SDPA is the default when you load the model. To switch to Flash-Attention 2, first install flash-attn; refer to the [original repository](https://github.com/Dao-AILab/flash-attention) for installation instructions. Then change the model-loading call in the snippets above to:
```python
import torch
from transformers import ChameleonForCausalLM

model_id = "facebook/chameleon-7b"  # the 7B checkpoint on the Hub
model = ChameleonForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2"
).to(0)
```
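
If the same script has to run on machines with and without flash-attn installed, one option is to pick the `attn_implementation` value at runtime. The sketch below is a hypothetical helper, not part of Transformers: it only checks whether the `flash_attn` package can be imported and does not verify that the GPU or dtype is supported by Flash-Attention 2.

```python
import importlib.util


def pick_attn_implementation(prefer_flash=True):
    """Return "flash_attention_2" if flash-attn is importable, else fall back to "sdpa"."""
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


attn_implementation = pick_attn_implementation()
print(attn_implementation)
```

The returned string can then be passed as `attn_implementation=...` to `from_pretrained`.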
## ChameleonConfig
[[autodoc]] ChameleonConfig
## ChameleonVQVAEConfig
[[autodoc]] ChameleonVQVAEConfig
## ChameleonProcessor
[[autodoc]] ChameleonProcessor
## ChameleonImageProcessor
[[autodoc]] ChameleonImageProcessor
- preprocess
## ChameleonVQVAE
[[autodoc]] ChameleonVQVAE
- forward
## ChameleonModel
[[autodoc]] ChameleonModel
- forward
## ChameleonForCausalLM
[[autodoc]] ChameleonForCausalLM
- forward