Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-20 04:58:22 +06:00
[Pixtral] Improve docs, rename model (#33491)
* Improve docs, rename model
* Fix style
* Update repo id
This commit is contained in: parent c6379858f3, commit 06e27e3dc0
@@ -255,7 +255,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
 | [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
 | [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
-| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ |
+| [Pixtral](model_doc/pixtral) | ✅ | ❌ | ❌ |
 | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
 | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
 | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |
@@ -18,20 +18,22 @@ rendered properly in your Markdown viewer.
 
 ## Overview
 
-The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
+The Pixtral model was released by the Mistral AI team on [vLLM](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!
 
 Tips:
 
-- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
-- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
+- Pixtral is a multimodal model, taking images and text as input, and producing text as output.
+- This model follows the [Llava](llava) family, meaning image embeddings are placed instead of the `[IMG]` token placeholders. The model uses [`PixtralVisionModel`] for its vision encoder, and [`MistralForCausalLM`] for its language decoder.
+- The main contribution is the 2d ROPE (rotary position embeddings) on the images, and support for arbitrary image sizes (the images are not padded together nor are they resized).
 - The format for one or multiple prompts is the following:
 ```
 "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
 ```
 Then, the processor will replace each `[IMG]` token with a number of `[IMG]` tokens that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.
 
-This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
+This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/vllm-project/vllm/pull/8377).
 
 ## Usage
 
 Here is an example of how to run it:
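(The usage example itself sits outside this hunk. As a rough, hedged sketch of what running the full model looks like, assuming the Llava-style wrapper mentioned in the tips above and reusing the testing repo id from the test file further down this diff as a placeholder checkpoint:)

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Placeholder checkpoint id, borrowed from the tests later in this diff.
model_id = "hf-internal-testing/pixtral-12b"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

# One [IMG] placeholder per image, following the prompt format described above;
# the processor expands each [IMG] into as many image tokens as the image needs.
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)
generate_ids = model.generate(**inputs, max_new_tokens=100)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```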
@@ -83,9 +85,9 @@ Each image captures a different scene, from a close-up of a dog to expansive nat
 
 [[autodoc]] PixtralVisionConfig
 
-## PixtralModel
+## PixtralVisionModel
 
-[[autodoc]] PixtralModel
+[[autodoc]] PixtralVisionModel
     - forward
 
 ## PixtralImageProcessor
@@ -2994,7 +2994,7 @@ else:
                 "Pix2StructVisionModel",
             ]
         )
-    _import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
+    _import_structure["models.pixtral"].extend(["PixtralVisionModel", "PixtralPreTrainedModel"])
     _import_structure["models.plbart"].extend(
         [
             "PLBartForCausalLM",
@@ -7486,8 +7486,8 @@ if TYPE_CHECKING:
             Pix2StructVisionModel,
         )
         from .models.pixtral import (
-            PixtralModel,
             PixtralPreTrainedModel,
+            PixtralVisionModel,
         )
         from .models.plbart import (
             PLBartForCausalLM,
@@ -195,7 +195,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("persimmon", "PersimmonModel"),
         ("phi", "PhiModel"),
         ("phi3", "Phi3Model"),
-        ("pixtral", "PixtralModel"),
+        ("pixtral", "PixtralVisionModel"),
         ("plbart", "PLBartModel"),
         ("poolformer", "PoolFormerModel"),
         ("prophetnet", "ProphetNetModel"),
@@ -29,7 +29,7 @@ except OptionalDependencyNotAvailable:
     pass
 else:
     _import_structure["modeling_pixtral"] = [
-        "PixtralModel",
+        "PixtralVisionModel",
         "PixtralPreTrainedModel",
     ]
 
@@ -53,8 +53,8 @@ if TYPE_CHECKING:
         pass
     else:
         from .modeling_pixtral import (
-            PixtralModel,
             PixtralPreTrainedModel,
+            PixtralVisionModel,
         )
 
     try:
@@ -22,9 +22,9 @@ logger = logging.get_logger(__name__)
 
 class PixtralVisionConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
-    Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to that of the Pixtral-9B.
+    This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate a
+    Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B.
 
     e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
 
@@ -52,19 +52,17 @@ class PixtralVisionConfig(PretrainedConfig):
             Dropout probability for the attention layers.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
-        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-            Whether to tie the word embeddings with the input embeddings.
 
     Example:
 
     ```python
-    >>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
+    >>> from transformers import PixtralVisionModel, PixtralVisionConfig
 
-    >>> # Initializing a Pixtral 12B style configuration
+    >>> # Initializing a Pixtral-12B style configuration
     >>> config = PixtralVisionConfig()
 
-    >>> # Initializing a model from the pixtral 12B style configuration
-    >>> model = PixtralModel(configuration)
+    >>> # Initializing a model (with randomly initialized weights) from the configuration
+    >>> model = PixtralVisionModel(config)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
@@ -84,7 +82,6 @@ class PixtralVisionConfig(PretrainedConfig):
         hidden_act="gelu",
         attention_dropout=0.0,
         rope_theta=10000.0,
-        tie_word_embeddings=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -99,5 +96,4 @@ class PixtralVisionConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.hidden_act = hidden_act
         self.rope_theta = rope_theta
-        self.tie_word_embeddings = tie_word_embeddings
         self.head_dim = hidden_size // num_attention_heads
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,15 +48,13 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width):
 class PixtralRotaryEmbedding(nn.Module):
     """
     The key with pixtral embedding is just that you have a frequency for each pixel positions.
-    If you have height x width pixels (or embedding pixels)
-
-    then the frequency used for ROPE is given by indexing the pre_computed frequency on the
-    width and height.
+    If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
+    is given by indexing the pre_computed frequency on the width and height.
 
-    What you output is of dimension batch, height * width, dim with dim the embed dim.
+    What you output is of dimension (batch, height * width, dim) with dim the embed dim.
 
-    This simply means that for each image hidden states, you are going to add
-    a corresponding positional embedding, based on it's index in the grid.
+    This simply means that for each image hidden state, you are going to add
+    a corresponding positional embedding, based on its index in the grid.
     """
 
     def __init__(self, config, device):
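(A minimal sketch of the frequency-indexing idea this docstring describes, purely illustrative and not the actual `PixtralRotaryEmbedding` or `position_ids_in_meshgrid` code; the helper name and dimensions below are made up:)

```python
import torch

def meshgrid_position_ids(height, width, max_width):
    # One flat id per patch: id = row * max_width + col, so a single precomputed
    # frequency table can be indexed for any image size up to max_width patches wide.
    rows = torch.arange(height).unsqueeze(1).expand(height, width)
    cols = torch.arange(width).unsqueeze(0).expand(height, width)
    return (rows * max_width + cols).reshape(-1)

# Illustrative dimensions: one frequency row per possible grid position.
max_positions, dim, theta = 64, 32, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)  # (max_positions, dim // 2)

# For a 3 x 5 patch grid, each patch picks its own row of `freqs`.
ids = meshgrid_position_ids(3, 5, max_width=8)   # shape (15,)
per_patch_freqs = freqs[ids]                     # shape (height * width, dim // 2)
print(per_patch_freqs.shape)                     # torch.Size([15, 16])
```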
@@ -319,9 +317,7 @@ class PixtralTransformer(nn.Module):
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
-                than the model's internal embedding lookup matrix.
+                Embeddings which serve as input to the Transformer.
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -392,17 +388,13 @@ PIXTRAL_START_DOCSTRING = r"""
     and behavior.
 
     Parameters:
-        config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
+        config ([`PixtralVisionConfig`]):
+            Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
             load the weights associated with the model, only the configuration. Check out the
             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
 
 
 @add_start_docstrings(
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     PIXTRAL_START_DOCSTRING,
 )
 class PixtralPreTrainedModel(PreTrainedModel):
     config_class = PixtralVisionConfig
     base_model_prefix = "model"
@@ -412,9 +404,6 @@ class PixtralPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
 
     def _init_weights(self, module):
-        # important: this ported version of Pixtral isn't meant for training from scratch - only
-        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
-        # https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
         std = (
             self.config.initializer_range
             if hasattr(self.config, "initializer_range")
@@ -433,8 +422,9 @@ class PixtralPreTrainedModel(PreTrainedModel):
 
 PIXTRAL_INPUTS_DOCSTRING = r"""
     Args:
-        pixel_values: list of N_img images of variable sizes,
-            each of shape (C, H, W)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
+            for details.
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
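(For context, the `pixel_values` described in the rewritten docstring would typically come from the image processor. A small, hedged sketch; the repo id is the testing placeholder used later in this diff, and the output may be a nested list rather than a single tensor when images of different sizes are batched:)

```python
import requests
from PIL import Image
from transformers import AutoImageProcessor

# Placeholder checkpoint id from the tests in this same diff.
image_processor = AutoImageProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Images are not resized to a fixed square, so the output shape depends on the input image.
features = image_processor(images=image, return_tensors="pt")
pixel_values = features["pixel_values"]
print(type(pixel_values))
```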
@@ -463,10 +453,10 @@ def generate_block_attention_mask(patch_embeds_list, tensor):
 
 
 @add_start_docstrings(
-    """The PIXTRAL model which consists of a vision backbone and a language model.""",
+    "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
     PIXTRAL_START_DOCSTRING,
 )
-class PixtralModel(PixtralPreTrainedModel):
+class PixtralVisionModel(PixtralPreTrainedModel):
     base_model_prefix = "vision_encoder"
 
     def __init__(self, config):
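(The `generate_block_attention_mask` helper referenced in the hunk header above keeps patches from different images from attending to each other when several variable-sized images are packed into one sequence. A simplified, hedged sketch of that idea, not the actual function or its signature:)

```python
import torch

def block_diagonal_attention_mask(patch_counts, dtype=torch.float32):
    # Additive mask: 0 inside each image's own block of patches, a large negative
    # value everywhere else, so attention never crosses image boundaries.
    total = sum(patch_counts)
    mask = torch.full((total, total), torch.finfo(dtype).min, dtype=dtype)
    start = 0
    for count in patch_counts:
        mask[start : start + count, start : start + count] = 0.0
        start += count
    return mask

# Two images contributing 4 and 6 patches to the packed sequence.
print(block_diagonal_attention_mask([4, 6]).shape)  # torch.Size([10, 10])
```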
@@ -7102,14 +7102,14 @@ class Pix2StructVisionModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
-class PixtralModel(metaclass=DummyObject):
+class PixtralPreTrainedModel(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
-class PixtralPreTrainedModel(metaclass=DummyObject):
+class PixtralVisionModel(metaclass=DummyObject):
     _backends = ["torch"]
 
     def __init__(self, *args, **kwargs):
@@ -21,8 +21,8 @@ import requests
 
 from transformers import (
     AutoProcessor,
-    PixtralModel,
     PixtralVisionConfig,
+    PixtralVisionModel,
     is_torch_available,
     is_vision_available,
 )
@@ -46,7 +46,7 @@ if is_vision_available():
     from PIL import Image
 
 
-class PixtralModelTester:
+class PixtralVisionModelTester:
     def __init__(
         self,
         parent,
@@ -107,7 +107,7 @@ class PixtralModelTester:
         )
 
     def create_and_check_model(self, config, pixel_values):
-        model = PixtralModel(config=config)
+        model = PixtralVisionModel(config=config)
         model.to(torch_device)
         model.eval()
         with torch.no_grad():
@@ -120,7 +120,7 @@ class PixtralModelTester:
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
 
     def create_and_check_model_with_projection(self, config, pixel_values):
-        model = PixtralModel(config=config)
+        model = PixtralVisionModel(config=config)
         model.to(torch_device)
         model.eval()
         with torch.no_grad():
@@ -140,17 +140,17 @@ class PixtralModelTester:
 
 
 @require_torch
-class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
+class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
     """
-    Model tester for `PixtralModel`.
+    Model tester for `PixtralVisionModel`.
     """
 
-    all_model_classes = (PixtralModel,) if is_torch_available() else ()
+    all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
 
     def setUp(self):
-        self.model_tester = PixtralModelTester(self)
+        self.model_tester = PixtralVisionModelTester(self)
         self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)
 
     @unittest.skip("model does not support input embeds")
@@ -261,7 +261,7 @@ class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
 
 
 @require_torch
-class PixtralModelIntegrationTest(unittest.TestCase):
+class PixtralVisionModelIntegrationTest(unittest.TestCase):
     def setUp(self):
         self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
 
@@ -273,7 +273,7 @@ class PixtralModelIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     def test_small_model_integration_test(self):
         # Let' s make sure we test the preprocessing to replace what is used
-        model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
+        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
 
         prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
         image_file = "https://pixtral-vl.github.io/static/images/view.jpg"