Add MLLama (#33703)
* current changes
* nit
* Add cross_attention_mask to processor
* multi-image fixed
* Add cross_attention_mask to processor
* cross attn works in all cases
* WIP refactoring function for image processor
* WIP refactoring image processor functions
* Refactor preprocess to use global loops instead of list nested list comps
* Docstrings
* Add channels unification
* fix dtype issues
* Update docsrings and format
* Consistent max_image_tiles
* current script
* updates
* Add convert to rgb
* Add image processor tests
* updates!
* update
* god damn it I am dumb sometimes
* Precompute aspect ratios
* now this works, full match
* fix 😉
* nits
* style
* fix model and conversion
* nit
* nit
* kinda works
* hack for sdpa non-contiguous bias
* nits here and there
* latest changes
* merge?
* run forward
* Add aspect_ratio_mask
* vision attention mask
* update script and config variable names
* nit
* nits
* be able to load
* style
* nits
* there
* nits
* make forward run
* small update
* enable generation multi-turn
* nit
* nit
* Clean up a bit for errors and typos
* A bit more constant fixes
* 90B keys and shapes match
* Fix for 11B model
* Fixup, remove debug part
* Docs
* Make max_aspect_ratio_id to be minimal
* Update image processing code to match new implementation
* Adjust conversion for final checkpoint state
* Change dim in repeat_interleave (according to meta code)
* tmp fix for num_tiles
* Fix for conversion (gate<->up, q/k_proj rope permute)
* nits
* codestyle
* Vision encoder fixes
* pass cross attn mask further
* Refactor aspect ratio mask
* Disable text-only generation
* Fix cross attention layers order, remove q/k norm rotation for cross attention layers
* Refactor gated position embeddings
* fix bugs but needs test with new weights
* rope scaling should be llama3
* Fix rope scaling name
* Remove debug for linear layer
* fix copies
* Make mask prepare private func
* Remove linear patch embed
* Make precomputed embeddings as nn.Embedding module
* MllamaPrecomputedAspectRatioEmbedding with config init
* Remove unused self.output_dim
* nit, intermediate layers
* Rename ln and pos_embed
* vision_chunk_size -> image_size
* return_intermediate -> intermediate_layers_indices
* vision_input_dim -> hidden_size
* Fix copied from statements
* fix most tests
* Fix more copied from
* layer_id->layer_idx
* Comment
* Fix tests for processor
* Copied from for _prepare_4d_causal_attention_mask_with_cache_position
* Style fix
* Add MllamaForCausalLM
* WIP fixing tests
* Remove duplicated layers
* Remove dummy file
* Fix style
* Fix consistency
* Fix some TODOs
* fix language_model instantiation, add docstring
* Move docstring, remove todos for precomputed embeds (we cannot init them properly)
* Add initial docstrings
* Fix
* fix some tests
* lets skip these
* nits, remove print, style
* Add one more copied from
* Improve test message
* Make validate func private
* Fix dummy objects
* Refactor `data_format` a bit + add comment
* typos/nits
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
* fix dummy objects and imports
* Add chat template config json
* remove num_kv_heads from vision attention
* fix
* move some commits and add more tests
* fix test
* Remove `update_key_name` from modeling utils
* remove num-kv-heads again
* some preliminary docs
* Update chat template + tests
* nit, conversion script max_num_tiles from params
* Fix warning for text-only generation
* Update conversion script for instruct models
* Update chat template in conversion + test
* add tests for CausalLM model
* model_max_length, avoid null chat_template
* Refactor conversion script
* Fix forward
* Fix integration tests
* Refactor vision config + docs
* Fix default
* Refactor text config
* Doc fixes
* Remove unused args, fix docs example
* Squashed commit of the following:
commit b51ce5a2efffbecdefbf6fc92ee87372ec9d8830
Author: qubvel <qubvel@gmail.com>
Date: Wed Sep 18 13:39:15 2024 +0000
Move model + add output hidden states and output attentions
* Fix num_channels
* Add mllama text and mllama vision models
* Fixing repo consistency
* Style fix
* Fixing repo consistency
* Fixing unused config params
* Fix failed tests after refactoring
* hidden_activation -> hidden_act for text mlp
* Remove from_pretrained from sub-configs
* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/mllama/convert_mllama_weights_to_hf.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Reuse lambda in conversion script
* Remove run.py
* Update docs/source/en/model_doc/mllama.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/mllama/processing_mllama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Remove unused LlamaTokenizerFast
* Fix logging
* Refactor gating
* Remove cycle for collecting intermediate states
* Refactor text-only check, add integration test for text-only
* Revert from pretrained to configs
* Fix example
* Add auto `bos_token` adding in processor
* Fix tips
* Update src/transformers/models/auto/tokenization_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Enable supports_gradient_checkpointing model flag
* add eager/sdpa options
* don't skip attn tests and bring back GC skips (did i really remove those?)
* Fix signature, but get error with None gradient
* Fix output attention tests
* Disable GC back
* Change no split modules
* Fix dropout
* Style
* Add Mllama to sdpa list
* Add post init for vision model
* Refine config for MllamaForCausalLMModelTest and skipped tests for CausalLM model
* if skipped, say it, don't pass
* Clean vision tester config
* Doc for args
* Update tests/models/mllama/test_modeling_mllama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Add cross_attention_mask to test
* typehint
* Remove todo
* Enable gradient checkpointing
* Docstring
* Style
* Fixing and skipping some tests for new cache
* Mark flaky test
* Skip `test_sdpa_can_compile_dynamic` test
* Fixing some offload tests
* Add direct GenerationMixin inheritance
* Remove unused code
* Add initializer_range to vision config
* update the test to make sure we show if split
* fix gc?
* Fix repo consistency
* Undo modeling utils debug changes
* Fix link
* mllama -> Mllama
* [mllama] -> [Mllama]
* Enable compile test for CausalLM model (text-only)
* Fix TextModel prefix
* Update doc
* Docs for forward, type hints, and vision model prefix
* make sure to reset
* fix init
* small script refactor and styling
* nit
* updates!
* some nits
* Interpolate embeddings for 560 size and update integration tests
* nit
* does not support static cache!
* update
* fix
* nit2
* this?
* Fix conversion
* Style
* 4x memory improvement with image cache AFAIK
* Token decorator for tests
* Skip failing tests
* update processor errors
* fix split issues
* style
* weird
* style
* fix failing tests
* update
* nit fixing the whisper tests
* fix path
* update
---------
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: pavel <ubuntu@ip-10-90-0-11.ec2.internal>
Co-authored-by: qubvel <qubvel@gmail.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent 94f18cf23c
commit 19d58d31f1
@@ -860,6 +860,8 @@
       title: MatCha
     - local: model_doc/mgp-str
       title: MGP-STR
+    - local: model_doc/mllama
+      title: mllama
     - local: model_doc/nougat
       title: Nougat
     - local: model_doc/omdet-turbo
@@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ |
 | [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ |
 | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
+| [Mllama](model_doc/mllama) | ✅ | ❌ | ❌ |
 | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
 | [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
 | [MobileBERT](model_doc/mobilebert) | ✅ | ✅ | ❌ |
docs/source/en/model_doc/mllama.md (new file, 124 lines)
@@ -0,0 +1,124 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mllama

## Overview

The Llama 3.2-Vision collection of multimodal large language models (LLMs) is a collection of pretrained and instruction-tuned image reasoning generative models in 11B and 90B sizes (text + images in / text out). The Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image.

**Model Architecture:** Llama 3.2-Vision is built on top of the Llama 3.1 text-only model, which is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human preferences for helpfulness and safety. To support image recognition tasks, the Llama 3.2-Vision model uses a separately trained vision adapter that integrates with the pre-trained Llama 3.1 language model. The adapter consists of a series of cross-attention layers that feed image encoder representations into the core LLM.

## Usage Tips

- For image+text and text inputs use `MllamaForConditionalGeneration`.
- For text-only inputs use `MllamaForCausalLM` for generation to avoid loading the vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. The processor will pad the inputs to the maximum number of images across samples and to a maximum number of tiles within each image.
- The text passed to the processor should have the `"<|image|>"` tokens where the images should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor.

## Usage Example

#### Instruct model

```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What does the image show?"}
            ]
        }
    ],
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=25)
print(processor.decode(output[0]))
```
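
#### Batched multi-image inputs

The padding behavior described in the usage tips can be seen by batching samples with different numbers of images. The snippet below is a minimal sketch (not part of the original documentation); it assumes the processor accepts one nested list of images per sample, and the exact output key names and shapes may vary.

```python
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Two samples: the first references two images, the second only one.
# One "<|image|>" token is needed per image in each prompt.
prompts = [
    "<|image|><|image|>Compare the two images.",
    "<|image|>Describe the image.",
]
images = [[image, image], [image]]  # assumption: one inner list of images per sample

inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt")

# The processor pads every sample to the maximum number of images/tiles in the batch,
# so all returned tensors share a common batch dimension.
for name, tensor in inputs.items():
    print(name, tuple(tensor.shape))
```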

#### Base model

```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<|image|>If I had to write a haiku for this one"
url = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))
```
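
#### Text-only generation

The usage tips above recommend `MllamaForCausalLM` for text-only inputs so the vision tower is never loaded. Below is a minimal sketch of that path (not part of the original documentation); it assumes the text-only weights can be loaded standalone from the same checkpoint — if that is not the case for your checkpoint, fall back to `MllamaForConditionalGeneration` without passing images.

```python
import torch
from transformers import AutoTokenizer, MllamaForCausalLM

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumption: text weights load standalone from this checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MllamaForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# No "<|image|>" token and no pixel values: only the language model runs.
inputs = tokenizer("If I had to write a haiku about the sea, it would start with", return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```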

## MllamaConfig

[[autodoc]] MllamaConfig

## MllamaProcessor

[[autodoc]] MllamaProcessor


## MllamaImageProcessor

[[autodoc]] MllamaImageProcessor

## MllamaForConditionalGeneration

[[autodoc]] MllamaForConditionalGeneration
    - forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
    - forward

## MllamaTextModel

[[autodoc]] MllamaTextModel
    - forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
    - forward

## MllamaVisionModel

[[autodoc]] MllamaVisionModel
    - forward
@@ -236,6 +236,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
 * [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
 * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
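Since Mllama is added to the SDPA list above, the attention backend can be chosen at load time. A minimal sketch, assuming the standard `attn_implementation` argument of `from_pretrained` (not taken from this diff itself):

```python
import torch
from transformers import MllamaForConditionalGeneration

# "sdpa" is the default on recent PyTorch versions; "eager" remains available as a fallback.
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)  # "sdpa"
```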
@@ -577,6 +577,10 @@ _import_structure = {
     "models.mimi": ["MimiConfig"],
     "models.mistral": ["MistralConfig"],
     "models.mixtral": ["MixtralConfig"],
+    "models.mllama": [
+        "MllamaConfig",
+        "MllamaProcessor",
+    ],
     "models.mluke": [],
     "models.mobilebert": [
         "MobileBertConfig",
@@ -1199,6 +1203,7 @@ else:
     )
     _import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
     _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
+    _import_structure["models.mllama"].extend(["MllamaImageProcessor"])
     _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
     _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
     _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
@@ -2704,6 +2709,16 @@ else:
             "MixtralPreTrainedModel",
         ]
     )
+    _import_structure["models.mllama"].extend(
+        [
+            "MllamaForCausalLM",
+            "MllamaForConditionalGeneration",
+            "MllamaPreTrainedModel",
+            "MllamaProcessor",
+            "MllamaTextModel",
+            "MllamaVisionModel",
+        ]
+    )
     _import_structure["models.mobilebert"].extend(
         [
             "MobileBertForMaskedLM",
@@ -5377,6 +5392,10 @@ if TYPE_CHECKING:
     )
     from .models.mistral import MistralConfig
     from .models.mixtral import MixtralConfig
+    from .models.mllama import (
+        MllamaConfig,
+        MllamaProcessor,
+    )
     from .models.mobilebert import (
         MobileBertConfig,
         MobileBertTokenizer,
@@ -6037,6 +6056,7 @@ if TYPE_CHECKING:
         MaskFormerFeatureExtractor,
         MaskFormerImageProcessor,
     )
+    from .models.mllama import MllamaImageProcessor
     from .models.mobilenet_v1 import (
         MobileNetV1FeatureExtractor,
         MobileNetV1ImageProcessor,
@@ -7270,6 +7290,14 @@ if TYPE_CHECKING:
         MixtralModel,
         MixtralPreTrainedModel,
     )
+    from .models.mllama import (
+        MllamaForCausalLM,
+        MllamaForConditionalGeneration,
+        MllamaPreTrainedModel,
+        MllamaProcessor,
+        MllamaTextModel,
+        MllamaVisionModel,
+    )
     from .models.mobilebert import (
         MobileBertForMaskedLM,
         MobileBertForMultipleChoice,
@@ -80,10 +80,12 @@ class Cache(torch.nn.Module):
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.key_cache[layer_idx] != []:
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.value_cache[layer_idx] != []:
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

     @property
     def seen_tokens(self):
@@ -358,10 +360,14 @@ class DynamicCache(Cache):
     ```
     """

-    def __init__(self) -> None:
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
+        if num_hidden_layers is None:
+            self.key_cache: List[torch.Tensor] = []
+            self.value_cache: List[torch.Tensor] = []
+        else:
+            self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
+            self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -420,6 +426,11 @@ class DynamicCache(Cache):
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+        # content on layer cache can be a tensor and checking not tensor causes errors
+        # so we explicitly check for the empty list
+        elif self.key_cache[layer_idx] == []:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
         else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
@@ -429,7 +440,7 @@ class DynamicCache(Cache):
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # TODO: deprecate this function in favor of `cache_position`
-        if len(self.key_cache) <= layer_idx:
+        if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
             return 0
         return self.key_cache[layer_idx].shape[-2]

@@ -446,10 +457,12 @@ class DynamicCache(Cache):
         return legacy_cache

     @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
+    ) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
-        cache = cls()
+        cache = cls(num_hidden_layers)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx]
@@ -468,15 +481,16 @@ class DynamicCache(Cache):

         self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
-            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+            if self.key_cache[idx] != []:
+                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

-    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
         for i in range(0, full_batch_size, split_size):
-            current_split = DynamicCache()
+            current_split = DynamicCache(num_hidden_layers)
             current_split._seen_tokens = self._seen_tokens
             current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
             current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
@@ -484,14 +498,17 @@ class DynamicCache(Cache):
         return out

     @classmethod
-    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        cache = cls()
+        cache = cls(num_hidden_layers)
         for idx in range(len(splits[0])):
-            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
-            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
-            cache.update(layer_keys, layer_values, idx)
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            if key_cache != []:
+                layer_keys = torch.cat(key_cache, dim=0)
+                layer_values = torch.cat(value_cache, dim=0)
+                cache.update(layer_keys, layer_values, idx)
         return cache

     def batch_repeat_interleave(self, repeats: int):
@@ -1391,10 +1408,13 @@ class EncoderDecoderCache(Cache):

     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
-        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+        cache = cls(
+            self_attention_cache=DynamicCache(num_hidden_layers),
+            cross_attention_cache=DynamicCache(num_hidden_layers),
+        )
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx][:2]
@@ -1407,7 +1427,10 @@ class EncoderDecoderCache(Cache):

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.self_attention_cache.key_cache) <= layer_idx:
+        # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
+        if self.self_attention_cache.key_cache == []:
             return 0
+        if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
+            return 0
         return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

@@ -1448,12 +1471,14 @@ class EncoderDecoderCache(Cache):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)

-    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+    def batch_split(
+        self, full_batch_size: int, split_size: int, num_hidden_layers: int
+    ) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)
-        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
-        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)

         out = []
         for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
@@ -1461,11 +1486,11 @@ class EncoderDecoderCache(Cache):
         return out

     @classmethod
-    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        self_attention_cache = DynamicCache()
-        cross_attention_cache = DynamicCache()
+        self_attention_cache = DynamicCache(num_hidden_layers)
+        cross_attention_cache = DynamicCache(num_hidden_layers)
         for idx in range(len(splits[0])):
             layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
             layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
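The `DynamicCache` changes above pre-allocate one (possibly empty) slot per layer, so layers that never receive states — such as Mllama's cross-attention layers during text-only generation — are simply skipped when reordering, cropping, or measuring the cache. A small illustrative sketch based on the code in this diff (the `num_hidden_layers` argument is the one introduced here):

```python
import torch
from transformers import DynamicCache

# Pre-allocate slots for 4 layers; every slot starts as an empty list.
cache = DynamicCache(4)

key = torch.zeros(1, 8, 5, 64)    # (batch, heads, seq_len, head_dim)
value = torch.zeros(1, 8, 5, 64)

# Only layers 0 and 2 receive states; layers 1 and 3 stay empty.
cache.update(key, value, layer_idx=0)
cache.update(key, value, layer_idx=2)

print(cache.get_seq_length(0))  # 5
print(cache.get_seq_length(1))  # 0 -- empty layers report no cached tokens
```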
@@ -398,12 +398,15 @@ def _crop_past_key_values(model, past_key_values, max_length):
         past_key_values.crop(max_length)
     elif past_key_values is not None:
         for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :max_length, :],
-                    past_key_values[idx][1][:, :, :max_length, :],
-                )
-            )
+            if past_key_values[idx] != ([], []):
+                new_past.append(
+                    (
+                        past_key_values[idx][0][:, :, :max_length, :],
+                        past_key_values[idx][1][:, :, :max_length, :],
+                    )
+                )
+            else:
+                new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
         past_key_values = tuple(new_past)
     return past_key_values

@@ -32,6 +32,7 @@ from ..cache_utils import (
     OffloadedCache,
     QuantizedCacheConfig,
 )
+from ..configuration_utils import PretrainedConfig
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..pytorch_utils import isin_mps_friendly
@@ -1601,10 +1602,11 @@ class GenerationMixin:
             # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
             # keeps copying the cache thus using much more memory
             else:
+                num_hidden_layers = self.config.get_text_config().num_hidden_layers
                 model_kwargs[cache_name] = (
-                    DynamicCache()
+                    DynamicCache(num_hidden_layers)
                     if not requires_cross_attention_cache
-                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
+                    else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
                 )

     def _supports_num_logits_to_keep(self) -> bool:
@@ -2384,11 +2386,7 @@ class GenerationMixin:
         this_peer_finished = False

         # prepare layers for DoLa decoding
-        final_layer = (
-            self.config.text_config.num_hidden_layers
-            if hasattr(self.config, "text_config")
-            else self.config.num_hidden_layers
-        )
+        final_layer = self.config.get_text_config().num_hidden_layers
         # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
         # as the early exit from word embeddings will become identity function
         # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
@@ -2736,7 +2734,7 @@ class GenerationMixin:
                     model_kwargs["past_key_values"].crop(-1)

                 all_outputs.append(outputs)
-            outputs = stack_model_outputs(all_outputs)
+            outputs = stack_model_outputs(all_outputs, self.config.get_text_config())

         else:
             # compute the candidate tokens by the language model and collect their hidden_states
@@ -3014,8 +3012,7 @@ class GenerationMixin:

             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
-            # .float() is needed to retain precision for later logits manipulations
-            next_token_logits = outputs.logits[:, -1, :].clone().float()
+            next_token_logits = outputs.logits.clone()[:, -1, :].float()

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3242,13 +3239,16 @@ class GenerationMixin:
                 )

                 inputs_per_sub_batches = _split_model_inputs(
-                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
+                    model_inputs,
+                    split_size=batch_size,
+                    full_batch_size=batch_beam_size,
+                    config=self.config.get_text_config(),
                 )
                 outputs_per_sub_batch = [
                     self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]

-                outputs = stack_model_outputs(outputs_per_sub_batch)
+                outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())

             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)
@@ -4004,7 +4004,7 @@ class GenerationMixin:
                 isinstance(past_key_values, EncoderDecoderCache)
                 and isinstance(past_key_values.self_attention_cache, DynamicCache)
             ):
-                if len(past_key_values) == 0:
+                if past_key_values.get_seq_length() == 0:
                     start_from_empty_dynamic_cache = True

         this_peer_finished = False
@@ -4313,7 +4313,7 @@ def _ranking_fast(
     return selected_idx


-def _split(data, full_batch_size: int, split_size: int = None):
+def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None):
     """
     Takes care of three cases:
     1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
@@ -4331,7 +4331,7 @@ def _split(data, full_batch_size: int, split_size: int = None):
     elif isinstance(data, DynamicCache) or (
         isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
     ):
-        return data.batch_split(full_batch_size, split_size)
+        return data.batch_split(full_batch_size, split_size, num_hidden_layers)
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
@@ -4350,7 +4350,7 @@ def _split(data, full_batch_size: int, split_size: int = None):


 def _split_model_inputs(
-    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int, config: PretrainedConfig
 ) -> List[Union[ModelOutput, Dict]]:
     """
     Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
@@ -4384,16 +4384,20 @@ def _split_model_inputs(
     keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
     non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

+    num_hidden_layers = config.get_text_config().num_hidden_layers
+
     # we split the tensors and tuples of tensors
     data_split_list = [
-        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys}
         for i in range(full_batch_size // split_size)
     ]
     # bool values are the same and replicated for each split
     bool_data = {k: model_input[k] for k in bool_keys}
     # encoder_outputs is a ModelOutput object and should be split by its own
     if "encoder_outputs" in model_input:
-        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        encoder_outputs_split = _split_model_inputs(
+            model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config()
+        )
         data_split_list = [
             {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
         ]
@@ -4411,7 +4415,7 @@ def _split_model_inputs(
     return split_model_inputs


-def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConfig) -> ModelOutput:
     """
     Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
     specific ModelOutput subclass from the list provided.
@@ -4421,6 +4425,7 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:

     # Infer the class from the first object in the list
     model_output_cls = type(model_outputs[0])
+    num_hidden_layers = config.get_text_config().num_hidden_layers

     # Ensure all objects are of the same type
     if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
@@ -4437,9 +4442,9 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
             return torch.cat(data, dim=0)
         # New cache format
         elif isinstance(data[0], DynamicCache):
-            return DynamicCache.from_batch_splits(data)
+            return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
         elif isinstance(data[0], EncoderDecoderCache):
-            return EncoderDecoderCache.from_batch_splits(data)
+            return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
         elif isinstance(data[0], tuple):
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
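Several of the hunks above replace direct `self.config.num_hidden_layers` lookups with `self.config.get_text_config().num_hidden_layers`, which resolves the nested text config on composite models such as Mllama and falls back to the config itself otherwise. A small illustrative sketch (defaults taken from the `MllamaTextConfig` added later in this diff):

```python
from transformers import MllamaConfig, MistralConfig

# Composite config: get_text_config() returns the nested text config.
mllama_config = MllamaConfig()
print(mllama_config.get_text_config().num_hidden_layers)  # 40 (MllamaTextConfig default)

# Plain text-only config: get_text_config() just returns the config itself.
mistral_config = MistralConfig()
print(mistral_config.get_text_config().num_hidden_layers)  # 32
```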
@@ -153,6 +153,7 @@ from . import (
     mimi,
     mistral,
     mixtral,
+    mllama,
     mluke,
     mobilebert,
     mobilenet_v1,
@@ -172,6 +172,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("mimi", "MimiConfig"),
         ("mistral", "MistralConfig"),
         ("mixtral", "MixtralConfig"),
+        ("mllama", "MllamaConfig"),
         ("mobilebert", "MobileBertConfig"),
         ("mobilenet_v1", "MobileNetV1Config"),
         ("mobilenet_v2", "MobileNetV2Config"),
@@ -477,6 +478,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("mimi", "Mimi"),
         ("mistral", "Mistral"),
         ("mixtral", "Mixtral"),
+        ("mllama", "Mllama"),
         ("mluke", "mLUKE"),
         ("mms", "MMS"),
         ("mobilebert", "MobileBERT"),
@@ -103,6 +103,7 @@ else:
         ("mask2former", ("Mask2FormerImageProcessor",)),
         ("maskformer", ("MaskFormerImageProcessor",)),
         ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
+        ("mllama", ("MllamaImageProcessor",)),
         ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
         ("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
         ("mobilevit", ("MobileViTImageProcessor",)),
@@ -327,6 +327,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
         ("mamba2", "Mamba2ForCausalLM"),
         ("mega", "MegaForMaskedLM"),
         ("megatron-bert", "MegatronBertForPreTraining"),
+        ("mllama", "MllamaForConditionalGeneration"),
         ("mobilebert", "MobileBertForPreTraining"),
         ("mpnet", "MPNetForMaskedLM"),
         ("mpt", "MptForCausalLM"),
@@ -500,6 +501,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("megatron-bert", "MegatronBertForCausalLM"),
         ("mistral", "MistralForCausalLM"),
         ("mixtral", "MixtralForCausalLM"),
+        ("mllama", "MllamaForCausalLM"),
         ("mpt", "MptForCausalLM"),
         ("musicgen", "MusicgenForCausalLM"),
         ("musicgen_melody", "MusicgenMelodyForCausalLM"),
@@ -566,6 +568,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
         ("hiera", "HieraModel"),
         ("imagegpt", "ImageGPTModel"),
         ("levit", "LevitModel"),
+        ("mllama", "MllamaVisionModel"),
         ("mobilenet_v1", "MobileNetV1Model"),
         ("mobilenet_v2", "MobileNetV2Model"),
         ("mobilevit", "MobileViTModel"),
@@ -737,6 +740,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
         ("llava_next", "LlavaNextForConditionalGeneration"),
         ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
         ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+        ("mllama", "MllamaForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),
         ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
@@ -1338,6 +1342,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
         ("flaubert", "FlaubertModel"),
         ("ibert", "IBertModel"),
         ("longformer", "LongformerModel"),
+        ("mllama", "MllamaTextModel"),
         ("mobilebert", "MobileBertModel"),
         ("mt5", "MT5EncoderModel"),
         ("nystromformer", "NystromformerModel"),
@@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("markuplm", "MarkupLMProcessor"),
         ("mctct", "MCTCTProcessor"),
         ("mgp-str", "MgpstrProcessor"),
+        ("mllama", "MllamaProcessor"),
         ("oneformer", "OneFormerProcessor"),
         ("owlv2", "Owlv2Processor"),
         ("owlvit", "OwlViTProcessor"),
@@ -305,6 +305,7 @@ else:
                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
         ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
         ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
src/transformers/models/mllama/__init__.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
    "configuration_mllama": ["MllamaConfig"],
    "processing_mllama": ["MllamaProcessor"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_mllama"] = [
        "MllamaForConditionalGeneration",
        "MllamaForCausalLM",
        "MllamaTextModel",
        "MllamaVisionModel",
        "MllamaPreTrainedModel",
    ]

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_mllama"] = ["MllamaImageProcessor"]


if TYPE_CHECKING:
    from .configuration_mllama import MllamaConfig
    from .processing_mllama import MllamaProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_mllama import (
            MllamaForCausalLM,
            MllamaForConditionalGeneration,
            MllamaPreTrainedModel,
            MllamaTextModel,
            MllamaVisionModel,
        )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_mllama import (
            MllamaImageProcessor,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
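With the lazy `_import_structure` above, the new symbols resolve on first attribute access from the submodule, without importing torch or vision code up front. A quick sketch (illustrative only):

```python
# Nothing heavy is imported until an attribute is actually accessed.
from transformers.models import mllama

print(mllama.MllamaConfig)          # resolved lazily from configuration_mllama
print(mllama.MllamaImageProcessor)  # available only when vision dependencies are installed
```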
400
src/transformers/models/mllama/configuration_mllama.py
Normal file
400
src/transformers/models/mllama/configuration_mllama.py
Normal file
@ -0,0 +1,400 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mllama model configuration"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MllamaVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MllamaVisionModel`]. It is used to instantiate an
|
||||
Mllama vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Mllama-11B.
|
||||
|
||||
e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1280):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_global_layers (`int`, *optional*, defaults to 8):
|
||||
Number of global layers in the Transformer encoder.
|
||||
Vision model has a second transformer encoder, called global.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input image.
|
||||
intermediate_size (`int`, *optional*, defaults to 5120):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
vision_output_dim (`int`, *optional*, defaults to 7680):
|
||||
Dimensionality of the vision model output. Includes output of transformer
|
||||
encoder with intermediate layers and global transformer encoder.
|
||||
image_size (`int`, *optional*, defaults to 448):
|
||||
The size (resolution) of each image *tile*.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
max_num_tiles (`int`, *optional*, defaults to 4):
|
||||
Maximum number of tiles for image splitting.
|
||||
intermediate_layers_indices (`List[int]`, *optional*, defaults to [3, 7, 15, 23, 30]):
|
||||
Indices of intermediate layers of transformer encoder from which to extract and output features.
|
||||
These output features are concatenated with final hidden state of transformer encoder.
|
||||
supported_aspect_ratios (`List[List[int]]`, *optional*):
|
||||
List of supported aspect ratios for image splitting. If not specified, the default supported aspect ratios
|
||||
are [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1], [2, 2], [3, 1], [4, 1]] for `max_num_tiles=4`.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MllamaVisionConfig, MllamaVisionModel
|
||||
|
||||
>>> # Initializing a Llama config
|
||||
>>> config = MllamaVisionConfig()
|
||||
|
||||
>>> # Initializing a vision model from the mllama-11b style configuration
|
||||
>>> model = MllamaVisionModel(config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mllama_vision_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1280,
|
||||
hidden_act: str = "gelu",
|
||||
num_hidden_layers: int = 32,
|
||||
num_global_layers: int = 8,
|
||||
num_attention_heads: int = 16,
|
||||
num_channels: int = 3,
|
||||
intermediate_size: int = 5120,
|
||||
vision_output_dim: int = 7680,
|
||||
image_size: int = 448,
|
||||
patch_size: int = 14,
|
||||
norm_eps: float = 1e-5,
|
||||
max_num_tiles: int = 4,
|
||||
intermediate_layers_indices: Optional[List[int]] = None,
|
||||
supported_aspect_ratios: Optional[List[List[int]]] = None,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if supported_aspect_ratios is None:
|
||||
if max_num_tiles != 4:
|
||||
raise ValueError("max_num_tiles must be 4 for default supported aspect ratios")
|
||||
supported_aspect_ratios = [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1], [2, 2], [3, 1], [4, 1]]
|
||||
|
||||
if intermediate_layers_indices is None:
|
||||
intermediate_layers_indices = [3, 7, 15, 23, 30]
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_channels = num_channels
|
||||
self.intermediate_size = intermediate_size
|
||||
self.image_size = image_size
|
||||
self.vision_output_dim = vision_output_dim
|
||||
self.patch_size = patch_size
|
||||
self.intermediate_layers_indices = intermediate_layers_indices
|
||||
self.num_global_layers = num_global_layers
|
||||
self.max_num_tiles = max_num_tiles
|
||||
self.norm_eps = norm_eps
|
||||
self.attention_heads = num_attention_heads
|
||||
self.supported_aspect_ratios = supported_aspect_ratios
|
||||
self.initializer_range = initializer_range
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def max_aspect_ratio_id(self) -> int:
|
||||
return len(self.supported_aspect_ratios)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if config_dict.get("model_type") == "mllama":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class MllamaTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MllamaTextModel`]. It is used to instantiate an
|
||||
Mllama text model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Mllama-11B.
|
||||
|
||||
e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128256):
|
||||
Vocabulary size of the Mllama text model. Defines the maximum number of different tokens that can be represented
|
||||
by the `inputs_ids` passed when calling [`MllamaTextModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the embeddings and hidden states.
|
||||
hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 40):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
|
||||
specified, will default to `num_attention_heads`.
|
||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
rope_theta (`float`, *optional*, defaults to 500000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
cross_attention_layers (`List[int]`, *optional*):
|
||||
Indices of the cross attention layers. If not specified, will default to [3, 8, 13, 18, 23, 28, 33, 38].
|
||||
dropout (`float`, *optional*, defaults to 0):
|
||||
The dropout probability for self- and cross-attention layers.
|
||||
bos_token_id (`int`, *optional*, defaults to 128000):
|
||||
The id of the beginning of sentence token.
|
||||
eos_token_id (`int`, *optional*, defaults to 128001):
|
||||
The id of the end of sentence token.
|
||||
pad_token_id (`int`, *optional*, defaults to 128004):
|
||||
The id of the padding token.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MllamaTextModel, MllamaTextConfig
|
||||
|
||||
>>> # Initializing a Mllama text config
|
||||
>>> config = MllamaTextConfig()
|
||||
|
||||
>>> # Initializing a model from the Mllama text configuration
|
||||
>>> model = MllamaTextModel(config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mllama_text_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 128256,
|
||||
hidden_size: int = 4096,
|
||||
hidden_act: str = "silu",
|
||||
num_hidden_layers: int = 40,
|
||||
num_attention_heads: int = 32,
|
||||
num_key_value_heads: int = 8,
|
||||
intermediate_size: int = 14_336,
|
||||
rope_theta: float = 500_000,
|
||||
rope_scaling: Optional[Dict] = None,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
max_position_embeddings: int = 131_072,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache: bool = True,
|
||||
tie_word_embeddings: bool = False,
|
||||
cross_attention_layers: Optional[List[int]] = None,
|
||||
dropout: float = 0,
|
||||
bos_token_id: int = 128000,
|
||||
eos_token_id: int = 128001,
|
||||
pad_token_id: Optional[int] = 128004,
|
||||
**kwargs,
|
||||
):
|
||||
if cross_attention_layers is None:
|
||||
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.cross_attention_layers = cross_attention_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.rope_scaling = rope_scaling
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if config_dict.get("model_type") == "mllama":
|
||||
config_dict = config_dict["text_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
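For reference, a minimal sketch of how the `rope_scaling` fields documented above fit together. The dictionary mirrors the llama3-style values used by the conversion script further down; it is illustrative, not a prescription for any particular checkpoint:

```python
from transformers import MllamaTextConfig

# llama3-style rope scaling dict; field names follow the docstring above,
# the numeric values match the conversion script below and are illustrative.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
config = MllamaTextConfig(rope_scaling=rope_scaling)
print(config.rope_scaling["rope_type"])  # "llama3"
```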
|
||||
|
||||
|
||||
class MllamaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MllamaForConditionalGeneration`]. It is used to instantiate an
|
||||
Mllama model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the Mllama-11B.
|
||||
|
||||
e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaTextConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
image_token_index (`int`, *optional*, defaults to 128256):
|
||||
The image token index to encode the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MllamaForConditionalGeneration, MllamaConfig, MllamaVisionConfig, MllamaTextConfig
|
||||
|
||||
>>> # Initializing a CLIP-vision config
|
||||
>>> vision_config = MllamaVisionConfig()
|
||||
|
||||
>>> # Initializing a Llama config
|
||||
>>> text_config = MllamaTextConfig()
|
||||
|
||||
>>> # Initializing a mllama-11b style configuration
|
||||
>>> configuration = MllamaConfig(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the mllama-11b style configuration
|
||||
>>> model = MllamaForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "mllama"
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
image_token_index=128256,
|
||||
**kwargs,
|
||||
):
|
||||
if vision_config is None:
|
||||
self.vision_config = MllamaVisionConfig()
|
||||
logger.info("vision_config is None, using default mllama vision config")
|
||||
elif isinstance(vision_config, dict):
|
||||
self.vision_config = MllamaVisionConfig(**vision_config)
|
||||
elif isinstance(vision_config, MllamaVisionConfig):
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.image_token_index = image_token_index
|
||||
|
||||
if text_config is None:
|
||||
self.text_config = MllamaTextConfig()
|
||||
logger.info("text_config is None, using default mllama text config")
|
||||
elif isinstance(text_config, dict):
|
||||
self.text_config = MllamaTextConfig(**text_config)
|
||||
elif isinstance(text_config, MllamaTextConfig):
|
||||
self.text_config = text_config
|
||||
|
||||
super().__init__(**kwargs)
|
src/transformers/models/mllama/convert_mllama_weights_to_hf.py (new file, 635 lines)
@@ -0,0 +1,635 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
MllamaConfig,
|
||||
MllamaForConditionalGeneration,
|
||||
MllamaImageProcessor,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import TikTokenConverter
|
||||
from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig
|
||||
from transformers.models.mllama.image_processing_mllama import get_all_supported_aspect_ratios
|
||||
|
||||
|
||||
# fmt: off
|
||||
# If a weight needs to be split in two or more keys, use `|` to indicate it. ex:
|
||||
# r"text_model.layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.q|k|v|_proj.weight"
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
r"text_model.norm.weight": r"language_model.model.norm.weight",
|
||||
r"text_model.output.weight": r"language_model.lm_head.weight",
|
||||
r"text_model.tok_embeddings": r"language_model.model.embed_tokens",
|
||||
r"text_model.learnable_embedding": r"language_model.model.learnable_embedding",
|
||||
r"text_model.rope.freqs": None, # meaning we skip it and don't want it
|
||||
# For every cross attention layer, the layer needs to be updated
|
||||
r"text_model.cross_attention_layers.(\d+).gate_attn": r"language_model.model.layers.\1.cross_attn_attn_gate",
|
||||
r"text_model.cross_attention_layers.(\d+).gate_ffwd": r"language_model.model.layers.\1.cross_attn_mlp_gate",
|
||||
# special key, wqkv needs to be split afterwards
|
||||
r"text_model.cross_attention_layers.(\d+).attention.w(q|k|v|o)": r"language_model.model.layers.\1.cross_attn.\2_proj",
|
||||
r"text_model.cross_attention_layers.(\d+).attention.(q|k)_norm": r"language_model.model.layers.\1.cross_attn.\2_norm",
|
||||
r"text_model.cross_attention_layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
|
||||
r"text_model.cross_attention_layers.(\d+).attention.wk.layer_norm_weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||
r"text_model.cross_attention_layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight",
|
||||
r"text_model.cross_attention_layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight",
|
||||
r"text_model.cross_attention_layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight",
|
||||
r"text_model.cross_attention_layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||
# self attention layers
|
||||
r"text_model.layers.(\d+).attention.w(q|k|v|o).weight": r"language_model.model.layers.\1.self_attn.\2_proj.weight",
|
||||
r"text_model.layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
|
||||
r"text_model.layers.(\d+).feed_forward.w1.": r"language_model.model.layers.\1.mlp.gate_proj.",
|
||||
r"text_model.layers.(\d+).feed_forward.w2.": r"language_model.model.layers.\1.mlp.down_proj.",
|
||||
r"text_model.layers.(\d+).feed_forward.w3.": r"language_model.model.layers.\1.mlp.up_proj.",
|
||||
r"text_model.layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
|
||||
# Vision encoder mapping
|
||||
r"vision_model.vision_encoder.conv1._linear": r"vision_model.patch_embedding",
|
||||
r'vision_model.vision_projection.': r"multi_modal_projector.",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).attn.wq": r"vision_model.\1.layers.\2.self_attn.q_proj",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).attn.wk": r"vision_model.\1.layers.\2.self_attn.k_proj",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).attn.wv": r"vision_model.\1.layers.\2.self_attn.v_proj",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).attn.wo": r"vision_model.\1.layers.\2.self_attn.o_proj",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).mlp.c_fc": r"vision_model.\1.layers.\2.mlp.fc1",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).mlp.c_proj": r"vision_model.\1.layers.\2.mlp.fc2",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_1": r"vision_model.\1.layers.\2.input_layernorm",
|
||||
r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_2": r"vision_model.\1.layers.\2.post_attention_layernorm",
|
||||
r"vision_model.vision_encoder.global_transformer.resblocks.(\d+).(gate_ffn|gate_attn)": r"vision_model.global_transformer.layers.\1.\2",
|
||||
r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.layernorm_\1.\2',
|
||||
r'vision_model.vision_encoder.positional_embedding\b': r'vision_model.gated_positional_embedding.embedding',
|
||||
r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding.weight',
|
||||
r'vision_model.vision_encoder.gated_positional_embedding_gate': r'vision_model.gated_positional_embedding.gate',
|
||||
r"vision_model.vision_encoder.pre_tile_pos_embed.embedding": r"vision_model.pre_tile_positional_embedding.embedding.weight",
|
||||
r"vision_model.vision_encoder.post_tile_pos_embed.embedding": r"vision_model.post_tile_positional_embedding.embedding.weight",
|
||||
r"vision_model.vision_encoder.pre_tile_pos_embed.gate": r"vision_model.pre_tile_positional_embedding.gate",
|
||||
r"vision_model.vision_encoder.post_tile_pos_embed.gate": r"vision_model.post_tile_positional_embedding.gate",
|
||||
r"vision_model.vision_encoder.(?=\w)": r"vision_model.",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
CONTEXT_LENGTH = 131072
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
|
||||
"""
|
||||
This function should be applied only once, on the concatenated keys, to efficiently rename them using
|
||||
the key mappings.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
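A small in-context sanity check of the renaming helper; the two keys below are made up to mirror the original Meta naming and are not read from a real checkpoint:

```python
# Bulk-renaming sketch on illustrative original keys. Note that the self-/cross-attention
# layer index is still the pre-shift one here; indices are remapped later in `write_model`.
example_keys = [
    "text_model.norm.weight",
    "text_model.layers.0.attention.wq.weight",
]
mapping = convert_old_keys_to_new_keys(example_keys)
# -> {'text_model.norm.weight': 'language_model.model.norm.weight',
#     'text_model.layers.0.attention.wq.weight': 'language_model.model.layers.0.self_attn.q_proj.weight'}
```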
|
||||
|
||||
|
||||
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
|
||||
"""
|
||||
When you go from the complex RoPE formulation to the sin/cos one, you need
|
||||
to permute the query and key weights (to avoid doing it on the fly)
|
||||
"""
|
||||
input_tensor = input_tensor.reshape(dim1, dim2)
|
||||
input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return input_tensor
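A shape-only sketch of `permute_for_rope` on a tiny random weight (4 heads, head dim 4, model dim 16); the numbers are illustrative:

```python
import torch

w = torch.randn(16, 16)  # (n_heads * head_dim, dim) = (dim1, dim2)
w_permuted = permute_for_rope(w, n_heads=4, dim1=16, dim2=16)
# The permutation only reorders rows within each head; the overall shape is unchanged.
assert w_permuted.shape == w.shape
```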
|
||||
|
||||
|
||||
def pre_compute_positional_embedding(embedding):
|
||||
"""
|
||||
Instead of iterating over the batch of images and the ratios inside, we pre-compute the
|
||||
positional embeddings depending on the aspect ratio id. This is done to support `torch.compile`
|
||||
and efficient inference / training with different aspect ratios.
|
||||
"""
|
||||
max_num_tiles, *shapes = embedding.shape
|
||||
hidden_size = shapes[-1]
|
||||
supported_aspect_ratios = get_all_supported_aspect_ratios(max_num_tiles)
|
||||
max_aspect_ratio_id = len(supported_aspect_ratios) # we keep 0 index for padding
|
||||
# tile embedding does not have patches
|
||||
num_patches = 1 if len(shapes) == 2 else shapes[1]
|
||||
precomputed_embeddings = torch.zeros(
|
||||
max_aspect_ratio_id + 1,
|
||||
max_num_tiles,
|
||||
num_patches,
|
||||
hidden_size,
|
||||
device=embedding.device,
|
||||
dtype=embedding.dtype,
|
||||
)
|
||||
|
||||
for i, (height, width) in enumerate(supported_aspect_ratios):
|
||||
aspect_ratio_id = i + 1 # we keep 0 index for padding
|
||||
current_embedding = embedding[:height, :width].reshape(height * width, num_patches, hidden_size)
|
||||
precomputed_embeddings[aspect_ratio_id, : height * width] = current_embedding
|
||||
precomputed_embeddings = precomputed_embeddings.flatten(1)
|
||||
return precomputed_embeddings
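A shape-only sketch of the precomputation for a tile positional embedding with `max_num_tiles=4` and a hidden size of 8 (values are random and illustrative):

```python
import torch

tile_embedding = torch.randn(4, 4, 8)  # (max_num_tiles, max_num_tiles, hidden_size)
precomputed = pre_compute_positional_embedding(tile_embedding)
# 8 supported aspect ratios for 4 tiles, plus one row kept for padding,
# flattened over (tiles, patches, hidden): (8 + 1, 4 * 1 * 8)
print(precomputed.shape)  # torch.Size([9, 32])
```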
|
||||
|
||||
|
||||
def is_param_different_across_shards(key):
|
||||
"""
|
||||
Return `True` if the parameter is different across checkpoint shards
|
||||
and needs to be concatenated.
|
||||
"""
|
||||
patterns = [r"vision_model.patch_embedding.weight",r"vision_model.(transformer|global_transformer).layers.(\d+).self_attn.(q|k|v|o)_proj.weight",r"vision_model.(transformer|global_transformer).layers.(\d+).mlp.fc1.(weight|bias)",r"vision_model.(transformer|global_transformer).layers.(\d+).mlp.fc2.weight", r"multi_modal_projector.(weight|bias)",r"language_model.model.embed_tokens.weight",r"language_model.lm_head.weight",r"language_model.model.layers.(\d+).self_attn.(q|k|v|o)_proj.weight",r"language_model.model.layers.(\d+).cross_attn.(q|k|v|o)_proj.weight",r"language_model.model.layers.(\d+).mlp.(up|down|gate)_proj.weight",r"language_model.model.learnable_embedding.weight"] # fmt: skip
|
||||
return any(re.search(pattern, key) for pattern in patterns)
|
||||
|
||||
|
||||
def get_concat_dim(key):
|
||||
"""
|
||||
Return the dimension to concatenate the weights on.
|
||||
"""
|
||||
concat_dim_1 = [r"vision_model.(transformer|global_transformer).layers.(\d+).mlp.fc2.weight",r"vision_model.(transformer|global_transformer).layers.(\d+).self_attn.o_proj.weight",r"language_model.model.layers.(\d+).cross_attn.o_proj.weight",r"language_model.model.layers.(\d+).self_attn.o_proj.weight",r"language_model.model.layers.(\d+).mlp.down_proj.weight"] # fmt: off
|
||||
if any(re.search(pattern, key) for pattern in concat_dim_1):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
|
||||
hidden_dim = 4 * int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
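As a worked example, assuming the checkpoint's `multiple_of` is 1024 (as in the Llama 3 family) and the default `ffn_dim_multiplier=1.3`, a text dim of 4096 reproduces the 14_336 intermediate size used as the `MllamaTextConfig` default:

```python
# 4 * int(2 * 4096 / 3) = 10920 -> * 1.3 = 14196 -> rounded up to a multiple of 1024 = 14336
assert compute_intermediate_size(4096, multiple_of=1024, ffn_dim_multiplier=1.3) == 14_336
```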
|
||||
|
||||
|
||||
def interpolate_positional_embedding(
|
||||
embeddings: torch.Tensor, vision_tile_size: int, vision_patch_size: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This method interpolates the pre-trained position embeddings so that the model can be used with higher-resolution
|
||||
images.
|
||||
"""
|
||||
cls_embedding, positional_embedding = embeddings[:1], embeddings[1:]
|
||||
total_num_patches, dim = positional_embedding.shape
|
||||
|
||||
# compute current and target number of patches for height and width
|
||||
num_patches = int(round(total_num_patches**0.5))
|
||||
new_num_patches = vision_tile_size // vision_patch_size
|
||||
|
||||
# Check if the number of patches is already the desired size
|
||||
if num_patches == new_num_patches:
|
||||
return embeddings
|
||||
|
||||
positional_embedding = positional_embedding.transpose(0, 1)
|
||||
positional_embedding = positional_embedding.reshape(1, dim, num_patches, num_patches)
|
||||
positional_embedding = F.interpolate(
|
||||
positional_embedding,
|
||||
size=(new_num_patches, new_num_patches),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
positional_embedding = positional_embedding.reshape(dim, -1).transpose(0, 1)
|
||||
|
||||
embeddings = torch.cat([cls_embedding, positional_embedding], dim=0)
|
||||
return embeddings
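A shape-only sketch: a (1 + 32*32, 1280) positional embedding interpolated for a hypothetical `vision_tile_size=560` and `vision_patch_size=14` grows to (1 + 40*40, 1280); the values are random and purely illustrative:

```python
import torch

emb = torch.randn(1 + 32 * 32, 1280)  # class token + 32x32 patch grid
new_emb = interpolate_positional_embedding(emb, vision_tile_size=560, vision_patch_size=14)
print(new_emb.shape)  # torch.Size([1601, 1280]), i.e. class token + 40x40 patch grid
```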
|
||||
|
||||
|
||||
def write_model(
|
||||
model_path,
|
||||
input_base_path,
|
||||
num_shards,
|
||||
safe_serialization=True,
|
||||
instruct=False,
|
||||
):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
with open(os.path.join(input_base_path, "params.json"), "r") as f:
|
||||
params = json.load(f)
|
||||
|
||||
params = params.get("model", params)
|
||||
torch_dtype = "bfloat16"
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Text model params and config
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# params from config
|
||||
text_vocab_size = params["vocab_size"]
|
||||
text_num_layers = params["n_layers"]
|
||||
text_dim = params["dim"]
|
||||
text_num_heads = params["n_heads"]
|
||||
text_rms_norm_eps = params["norm_eps"]
|
||||
text_rope_theta = params["rope_theta"]
|
||||
cross_attention_num_layers = params["vision_num_cross_attention_layers"]
|
||||
|
||||
# some constants from the original code
|
||||
rope_scaling = {
|
||||
"rope_type": "llama3",
|
||||
"factor": 8.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_max_position_embeddings": 8192,
|
||||
}
|
||||
max_position_embeddings = CONTEXT_LENGTH
|
||||
|
||||
# compute additional params for weight conversion
|
||||
text_num_heads_per_shard = text_num_heads // num_shards
|
||||
text_dim_per_head = text_dim // text_num_heads
|
||||
text_intermediate_size = compute_intermediate_size(text_dim, multiple_of=params["multiple_of"])
|
||||
|
||||
if params.get("n_kv_heads", None) is not None:
|
||||
text_num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
||||
text_num_key_value_heads_per_shard = text_num_key_value_heads // num_shards
|
||||
text_key_value_dim = text_dim_per_head * text_num_key_value_heads
|
||||
else: # compatibility with other checkpoints
|
||||
text_num_key_value_heads = text_num_heads
|
||||
text_num_key_value_heads_per_shard = text_num_heads_per_shard
|
||||
text_key_value_dim = text_dim
|
||||
|
||||
# cross-attention layers: 20 for 90B, 8 for 11B
|
||||
cross_attention_frequency = math.ceil(text_num_layers / cross_attention_num_layers)
|
||||
text_num_total_layers = text_num_layers + cross_attention_num_layers
|
||||
cross_attention_layers_shift = list(
|
||||
range(cross_attention_frequency - 1, text_num_total_layers, cross_attention_frequency + 1)
|
||||
)
|
||||
self_attention_layers_shift = [k for k in range(text_num_total_layers) if k not in cross_attention_layers_shift]
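# Illustrative example (the 11B layout): n_layers=32 and vision_num_cross_attention_layers=8 give
# cross_attention_frequency = ceil(32 / 8) = 4, text_num_total_layers = 40 and
# cross_attention_layers_shift = list(range(3, 40, 5)) = [3, 8, 13, 18, 23, 28, 33, 38],
# which matches the default `cross_attention_layers` of `MllamaTextConfig`.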
|
||||
|
||||
bos_token_id = 128000
|
||||
eos_token_id = [128001, 128008, 128009] if instruct else 128001
|
||||
pad_token_id = 128004
|
||||
|
||||
text_config = MllamaTextConfig(
|
||||
num_attention_heads=text_num_heads,
|
||||
vocab_size=text_vocab_size,
|
||||
hidden_size=text_dim,
|
||||
rms_norm_eps=text_rms_norm_eps,
|
||||
rope_theta=text_rope_theta,
|
||||
num_hidden_layers=text_num_total_layers,
|
||||
cross_attention_layers=cross_attention_layers_shift,
|
||||
intermediate_size=text_intermediate_size,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_scaling=rope_scaling,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
tie_word_embeddings=False, # Constant set to False
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Vision model params and config
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# params from config
|
||||
vision_tile_size = params["vision_chunk_size"]
|
||||
vision_max_num_tiles = params["vision_max_num_chunks"]
|
||||
|
||||
# some constants from original code
|
||||
vision_patch_size = 14
|
||||
vision_num_channels = 3
|
||||
vision_num_layers = 32
|
||||
vision_num_layers_global = 8
|
||||
vision_dim = 1280
|
||||
vision_num_heads = 16
|
||||
vision_intermediate_layers_indices = [3, 7, 15, 23, 30]
|
||||
|
||||
# compute additional params for weight conversion
|
||||
vision_dim_per_head = vision_dim // vision_num_heads
|
||||
vision_num_heads_per_shard = vision_num_heads // num_shards
|
||||
vision_intermediate_size = vision_dim * 4
|
||||
vision_supported_aspect_ratios = get_all_supported_aspect_ratios(vision_max_num_tiles)
|
||||
|
||||
vision_config = MllamaVisionConfig(
|
||||
hidden_size=vision_dim,
|
||||
patch_size=vision_patch_size,
|
||||
num_channels=vision_num_channels,
|
||||
intermediate_size=vision_intermediate_size,
|
||||
num_hidden_layers=vision_num_layers,
|
||||
num_attention_heads=vision_num_heads,
|
||||
num_global_layers=vision_num_layers_global,
|
||||
intermediate_layers_indices=vision_intermediate_layers_indices,
|
||||
image_size=vision_tile_size,
|
||||
max_num_tiles=vision_max_num_tiles,
|
||||
supported_aspect_ratios=vision_supported_aspect_ratios,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# save config
|
||||
config = MllamaConfig(vision_config=vision_config, text_config=text_config, torch_dtype=torch_dtype)
|
||||
config.architectures = ["MllamaForConditionalGeneration"]
|
||||
config.save_pretrained(model_path)
|
||||
print("Model config saved successfully...")
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Convert weights
|
||||
# ------------------------------------------------------------
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
|
||||
if num_shards == 1:
|
||||
loaded = [torch.load(os.path.join(input_base_path, "consolidated.pth"), map_location="cpu", mmap=True)]
|
||||
else:
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", mmap=True)
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
print("Converting model...")
|
||||
all_keys = list(loaded[0].keys())
|
||||
new_keys = convert_old_keys_to_new_keys(all_keys)
|
||||
|
||||
state_dict = {}
|
||||
for key in all_keys:
|
||||
new_key = new_keys[key]
|
||||
|
||||
# In the original model, self-attention layers and cross-attention layers are different lists of layers.
|
||||
# In the converted model, they are merged into one list with corresponding index shift to preserve the order.
|
||||
if ("cross_attention" in key or "text_model.layers" in key) and "language_model" in new_key:
|
||||
shift = cross_attention_layers_shift if "cross_attention" in key else self_attention_layers_shift
|
||||
new_key = re.sub(r"layers.(\d+).", lambda _match: f"layers.{shift[int(_match.groups()[0])]}.", new_key)
|
||||
|
||||
current_parameter = [chunk.pop(key).contiguous().clone() for chunk in loaded]
|
||||
if not is_param_different_across_shards(new_key):
|
||||
current_parameter = current_parameter[0]
|
||||
|
||||
concat_dim = get_concat_dim(new_key)
|
||||
|
||||
# Post-process the current_parameter.
|
||||
if re.search("(k|v|q)_proj.weight", new_key) and "language_model" in new_key:
|
||||
if "q_proj" in new_key:
|
||||
param_num_heads = text_num_heads
|
||||
param_num_head_per_shard = text_num_heads_per_shard
|
||||
param_dim = text_dim
|
||||
else:
|
||||
param_num_heads = text_num_key_value_heads
|
||||
param_num_head_per_shard = text_num_key_value_heads_per_shard
|
||||
param_dim = text_key_value_dim
|
||||
shards = [param.view(param_num_head_per_shard, text_dim_per_head, text_dim) for param in current_parameter]
|
||||
current_parameter = torch.cat(shards, dim=concat_dim)
|
||||
if "cross_attn" not in new_key and "v_proj.weight" not in new_key:
|
||||
current_parameter = permute_for_rope(current_parameter, param_num_heads, param_dim, text_dim)
|
||||
state_dict[new_key] = current_parameter.reshape(param_num_heads * text_dim_per_head, text_dim)
|
||||
|
||||
elif "vision_model" in new_key and re.search("(k|v|q)_proj", new_key):
|
||||
shards = [
|
||||
param.view(vision_num_heads_per_shard, vision_dim_per_head, vision_dim) for param in current_parameter
|
||||
]
|
||||
param = torch.cat(shards, dim=concat_dim)
|
||||
state_dict[new_key] = param.reshape(vision_num_heads * vision_dim_per_head, vision_dim)
|
||||
|
||||
elif new_key == "vision_model.patch_embedding.weight":
|
||||
current_parameter = torch.cat(current_parameter, dim=concat_dim)
|
||||
state_dict[new_key] = current_parameter.reshape(
|
||||
-1, vision_num_channels, vision_patch_size, vision_patch_size
|
||||
)
|
||||
|
||||
elif new_key.endswith("gate"):
|
||||
state_dict[new_key] = current_parameter[0].view(1)
|
||||
|
||||
elif "vision_model.gated_positional_embedding.embedding" in new_key:
|
||||
current_parameter = interpolate_positional_embedding(
|
||||
current_parameter, vision_tile_size, vision_patch_size
|
||||
)
|
||||
state_dict[new_key] = current_parameter
|
||||
|
||||
elif "vision_model.gated_positional_embedding.tile_embedding.weight" in new_key:
|
||||
current_parameter = current_parameter.permute(2, 0, 1, 3).flatten(1)
|
||||
current_parameter = interpolate_positional_embedding(
|
||||
current_parameter, vision_tile_size, vision_patch_size
|
||||
)
|
||||
current_parameter = current_parameter.reshape(
|
||||
-1, vision_max_num_tiles, vision_max_num_tiles, vision_dim
|
||||
).permute(1, 2, 0, 3)
|
||||
state_dict[new_key] = pre_compute_positional_embedding(current_parameter)
|
||||
|
||||
elif "tile_positional_embedding.embedding" in new_key:
|
||||
state_dict[new_key] = pre_compute_positional_embedding(current_parameter)
|
||||
|
||||
elif new_key != "":
|
||||
if isinstance(current_parameter, list):
|
||||
current_parameter = torch.cat(current_parameter, dim=concat_dim)
|
||||
state_dict[new_key] = current_parameter
|
||||
|
||||
state_dict["language_model.model.embed_tokens.weight"] = torch.cat(
|
||||
[
|
||||
state_dict["language_model.model.embed_tokens.weight"],
|
||||
state_dict.pop("language_model.model.learnable_embedding.weight"),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
del loaded
|
||||
gc.collect()
|
||||
|
||||
print("Loading the checkpoint in a Mllama model.")
|
||||
with torch.device("meta"):
|
||||
model = MllamaForConditionalGeneration(config)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
print("Checkpoint loaded successfully.")
|
||||
del model.config._name_or_path
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
||||
del state_dict, model
|
||||
|
||||
# Safety check: reload the converted model
|
||||
gc.collect()
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
MllamaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
print("Model reloaded successfully.")
|
||||
|
||||
# generation config
|
||||
if instruct:
|
||||
print("Saving generation config...")
|
||||
generation_config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.6,
|
||||
top_p=0.9,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
)
|
||||
generation_config.save_pretrained(model_path)
|
||||
|
||||
|
||||
class MllamaConverter(TikTokenConverter):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
special_tokens: List[str],
|
||||
pattern: str,
|
||||
model_max_length: int,
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(vocab_file, pattern=pattern)
|
||||
self.additional_special_tokens = special_tokens
|
||||
tokenizer = self.converted()
|
||||
if chat_template is not None:
|
||||
kwargs["chat_template"] = chat_template
|
||||
self.tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=tokenizer,
|
||||
model_input_names=["input_ids", "attention_mask"],
|
||||
model_max_length=model_max_length,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False):
|
||||
model_max_length = CONTEXT_LENGTH
|
||||
pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605
|
||||
|
||||
# Special tokens
|
||||
num_reserved_special_tokens = 256
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|finetune_right_pad_id|>",
|
||||
"<|step_id|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eom_id|>", # end of message
|
||||
"<|eot_id|>", # end of turn
|
||||
"<|python_tag|>",
|
||||
]
|
||||
special_tokens += [
|
||||
f"<|reserved_special_token_{i + 2}|>" for i in range(num_reserved_special_tokens - len(special_tokens))
|
||||
]
|
||||
# original tokenizer has <|image|> with 128011 token_id,
|
||||
# however, later in the code it is replaced with 128256 token_id
|
||||
special_tokens.append("<|image|>")
|
||||
|
||||
# Chat template
|
||||
chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if loop.index0 == 0 %}"
|
||||
"{{ bos_token }}"
|
||||
"{% endif %}"
|
||||
"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}"
|
||||
"{% if message['content'] is string %}"
|
||||
"{{ message['content'] }}"
|
||||
"{% else %}"
|
||||
"{% for content in message['content'] %}"
|
||||
"{% if content['type'] == 'image' %}"
|
||||
"{{ '<|image|>' }}"
|
||||
"{% elif content['type'] == 'text' %}"
|
||||
"{{ content['text'] }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% endif %}"
|
||||
"{{ '<|eot_id|>' }}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
|
||||
converter = MllamaConverter(
|
||||
vocab_file=tokenizer_path,
|
||||
pattern=pattern,
|
||||
special_tokens=special_tokens,
|
||||
model_max_length=model_max_length,
|
||||
chat_template=chat_template if instruct else None,
|
||||
bos_token="<|begin_of_text|>",
|
||||
eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
|
||||
pad_token="<|finetune_right_pad_id|>",
|
||||
)
|
||||
tokenizer = converter.tokenizer
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
|
||||
if instruct:
|
||||
print("Saving chat template...")
|
||||
chat_template_path = os.path.join(save_dir, "chat_template.json")
|
||||
with open(chat_template_path, "w") as f:
|
||||
json.dump({"chat_template": chat_template}, f, indent=2)
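A minimal sketch of how the saved chat template might be exercised afterwards; the directory name is a placeholder for wherever `save_dir` points:

```python
from transformers import AutoTokenizer

# Hypothetical output directory written by write_tokenizer above.
tokenizer = AutoTokenizer.from_pretrained("path/to/converted-model")
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]},
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# Expected to render roughly as:
# "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|>Describe this image.<|eot_id|>"
# "<|start_header_id|>assistant<|end_header_id|>\n\n"
```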
|
||||
|
||||
|
||||
def write_image_processor(config_path: str, save_dir: str):
|
||||
with open(config_path, "r") as f:
|
||||
params = json.load(f)
|
||||
|
||||
tile_size = params["vision_chunk_size"]
|
||||
max_image_tiles = params["vision_max_num_chunks"]
|
||||
|
||||
image_processor = MllamaImageProcessor(
|
||||
do_resize=True,
|
||||
size={"height": tile_size, "width": tile_size},
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_pad=True,
|
||||
max_image_tiles=max_image_tiles,
|
||||
)
|
||||
|
||||
image_processor.save_pretrained(save_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
default="Llama-3.2-11B-Vision/original",
|
||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="Llama-3.2-11B-Vision",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--special_tokens",
|
||||
default=None,
|
||||
type=List[str],
|
||||
help="The list of special tokens that should be added to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_shards",
|
||||
default=1,
|
||||
type=int,
|
||||
help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruct",
|
||||
action="store_true",
|
||||
help="Whether the model is an instruct model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=args.input_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
num_shards=args.num_shards,
|
||||
instruct=args.instruct,
|
||||
)
|
||||
|
||||
write_tokenizer(
|
||||
tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
|
||||
save_dir=args.output_dir,
|
||||
instruct=args.instruct,
|
||||
)
|
||||
|
||||
write_image_processor(
|
||||
config_path=os.path.join(args.input_dir, "params.json"),
|
||||
save_dir=args.output_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
src/transformers/models/mllama/image_processing_mllama.py (new file, 862 lines)
@@ -0,0 +1,862 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_transforms import (
|
||||
PaddingMode,
|
||||
get_image_size,
|
||||
pad,
|
||||
resize,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_valid_image,
|
||||
is_vision_available,
|
||||
to_numpy_array,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def get_all_supported_aspect_ratios(max_image_tiles: int) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Computes all allowed aspect ratios for a given maximum number of input tiles.
|
||||
|
||||
This function calculates all possible arrangements of tiles that can be formed
|
||||
within the constraint of the maximum number of tiles. Each arrangement is
|
||||
represented by its aspect ratio (width/height) and the corresponding tile configuration.
|
||||
|
||||
Args:
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles allowed.
|
||||
|
||||
Returns:
|
||||
`List[Tuple[int, int]]`: A list of tuples, each tuple representing a valid (width, height)
|
||||
configuration in terms of number of tiles.
|
||||
|
||||
Example:
|
||||
>>> get_all_supported_aspect_ratios(4)
|
||||
[(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]
|
||||
|
||||
"""
|
||||
aspect_ratios = []
|
||||
for width in range(1, max_image_tiles + 1):
|
||||
for height in range(1, max_image_tiles + 1):
|
||||
if width * height <= max_image_tiles:
|
||||
aspect_ratios.append((width, height))
|
||||
return aspect_ratios
|
||||
|
||||
|
||||
def get_image_size_fit_to_canvas(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
canvas_height: int,
|
||||
canvas_width: int,
|
||||
tile_size: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Calculates the new size of an image to fit within a canvas while maintaining aspect ratio.
|
||||
|
||||
This function calculates the optimal size for an image to fit within a canvas defined by
|
||||
canvas_height and canvas_width, while ensuring that the image dimensions are not smaller than
|
||||
tile_size. If the image is larger than the canvas, the returned size will fit within the canvas.
|
||||
If the image already fits within the canvas, the size remains unchanged.
|
||||
The aspect ratio of the original image is preserved.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
The height of the original image.
|
||||
image_width (`int`):
|
||||
The width of the original image.
|
||||
canvas_height (`int`):
|
||||
The height of the canvas.
|
||||
canvas_width (`int`):
|
||||
The width of the canvas.
|
||||
tile_size (`int`):
|
||||
The tile size.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: A tuple containing the new height and width of the image.
|
||||
|
||||
"""
|
||||
# Set the target image size between `tile_size` and the canvas size
|
||||
target_width = np.clip(image_width, tile_size, canvas_width)
|
||||
target_height = np.clip(image_height, tile_size, canvas_height)
|
||||
|
||||
scale_h = target_height / image_height
|
||||
scale_w = target_width / image_width
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(image_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(image_width * scale_h), target_width)
|
||||
|
||||
return new_height, new_width
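A worked example with illustrative numbers: an 896x1792 (height x width) image fit to a 448x896 canvas with `tile_size=448` is halved while keeping its 2:1 width/height ratio:

```python
print(get_image_size_fit_to_canvas(
    image_height=896, image_width=1792, canvas_height=448, canvas_width=896, tile_size=448
))  # new (height, width) = (448, 896)
```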
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_optimal_tiled_canvas(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
max_image_tiles: int,
|
||||
tile_size: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas based on image and tile size and maximum number of tiles.
|
||||
|
||||
First, calculates possible resolutions based on the maximum number of tiles and tile size.
|
||||
For example for max_image_tiles=2, tile_size=100, possible tile arrangements are:
|
||||
[(1, 1), (1, 2), (2, 1)] and corresponding canvas sizes are:
|
||||
[(100, 100), (100, 200), (200, 100)]
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
The height of the image.
|
||||
image_width (`int`):
|
||||
The width of the image.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles any image can be split into.
|
||||
tile_size (`int`):
|
||||
The tile size.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The best canvas resolution [height, width] for the given image.
|
||||
"""
|
||||
possible_tile_arrangements = get_all_supported_aspect_ratios(max_image_tiles)
|
||||
possible_canvas_sizes = np.array(possible_tile_arrangements) * tile_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_heights, target_widths = np.array(possible_canvas_sizes).T
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_h = target_heights / image_height
|
||||
scale_w = target_widths / image_width
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = np.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
selected_scale = np.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = np.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_canvas_sizes[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = np.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return optimal_canvas
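A worked example with illustrative numbers: an 800x1100 image with `tile_size=448` and at most 4 tiles cannot be upscaled to any canvas without distortion, so the canvas requiring the least downscaling is picked:

```python
print(get_optimal_tiled_canvas(image_height=800, image_width=1100, max_image_tiles=4, tile_size=448))
# [896 896]  (a 2x2 tile grid)
```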
|
||||
|
||||
|
||||
def split_to_tiles(image: np.ndarray, num_tiles_height: int, num_tiles_width: int) -> np.ndarray:
|
||||
"""
|
||||
Split an image into a specified number of tiles along its width and height dimensions.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Input image with shape (num_channels, height, width).
|
||||
num_tiles_height (`int`):
|
||||
Number of tiles to split the image into along its height.
|
||||
num_tiles_width (`int`):
|
||||
Number of tiles to split the image into along its width.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
Array of image tiles with shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width).
|
||||
"""
|
||||
num_channels, height, width = image.shape
|
||||
tile_height = height // num_tiles_height
|
||||
tile_width = width // num_tiles_width
|
||||
|
||||
image = image.reshape(num_channels, num_tiles_height, tile_height, num_tiles_width, tile_width)
|
||||
|
||||
# Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
|
||||
image = image.transpose(1, 3, 0, 2, 4)
|
||||
|
||||
# Reshape into the desired output shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
|
||||
image = image.reshape(num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
|
||||
|
||||
return np.ascontiguousarray(image)
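A shape-only sketch of `split_to_tiles` on a random (3, 896, 896) image split into a 2x2 grid; the array content is purely illustrative:

```python
import numpy as np

image = np.random.rand(3, 896, 896)
tiles = split_to_tiles(image, num_tiles_height=2, num_tiles_width=2)
print(tiles.shape)  # (4, 3, 448, 448)
```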
|
||||
|
||||
|
||||
def build_aspect_ratio_mask(aspect_ratios: List[List[Tuple[int, int]]], max_image_tiles: int) -> np.ndarray:
|
||||
"""
|
||||
Builds a mask for the aspect ratios of the images.
|
||||
|
||||
Args:
|
||||
aspect_ratios (`List[List[Tuple[int, int]]]`):
|
||||
A list of lists containing aspect ratios for each image in the batch.
|
||||
Each aspect ratio is represented as a tuple of (width, height) in terms of number of tiles.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles any image can be split into.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: A 3D numpy array of shape (batch_size, max_num_images, max_image_tiles).
|
||||
The mask contains 1s for valid tiles and 0s for padding.
|
||||
"""
|
||||
batch_size = len(aspect_ratios)
|
||||
max_num_images = max([len(row) for row in aspect_ratios])
|
||||
|
||||
aspect_ratio_mask = np.zeros((batch_size, max_num_images, max_image_tiles), dtype=np.int64)
|
||||
|
||||
# Set the first tile to 1 for all aspect ratios
|
||||
# because in original implementation aspect ratios are padded with (1, 1),
|
||||
# but original code examples are not built to handle batches, so we might remove it later
|
||||
aspect_ratio_mask[:, :, 0] = 1
|
||||
|
||||
# Set the aspect ratio mask for the rest of the tiles
|
||||
for i, sample_aspect_ratios in enumerate(aspect_ratios):
|
||||
for j, (num_tiles_w, num_tiles_h) in enumerate(sample_aspect_ratios):
|
||||
aspect_ratio_mask[i, j, : num_tiles_w * num_tiles_h] = 1
|
||||
|
||||
return aspect_ratio_mask
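A small sketch: one sample containing a single image split into a 2x2 grid, with `max_image_tiles=4`:

```python
print(build_aspect_ratio_mask([[(2, 2)]], max_image_tiles=4))
# [[[1 1 1 1]]]  -> all four tile slots are valid for this image
```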
|
||||
|
||||
|
||||
def pack_images(
|
||||
batch_images: List[List[np.ndarray]],
|
||||
max_image_tiles: int,
|
||||
) -> Tuple[np.ndarray, List[List[int]]]:
|
||||
"""
|
||||
Stack a list of lists of images with variable lengths into a numpy array, applying zero padding as needed.
|
||||
Each list in the input represents a batch sample, and each image within a list is expected to be
|
||||
pre-split into tiles. The resulting array will have a shape of
|
||||
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width).
|
||||
|
||||
Args:
|
||||
batch_images (`List[List[np.ndarray]]`):
|
||||
A list of lists of image tiles. Each inner list represents
|
||||
a batch sample containing multiple images, where each image is pre-split into tiles.
|
||||
The shape of each tile array is (num_tiles, channels, tile_height, tile_width).
|
||||
max_image_tiles (int):
|
||||
The maximum number of tiles any image may have been split into.
|
||||
|
||||
Returns:
|
||||
`Tuple[np.ndarray, List[List[int]]]`: A tuple containing:
|
||||
- stacked_images (`np.ndarray`):
|
||||
A numpy array of stacked images with shape
|
||||
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width).
|
||||
- all_num_tiles (`List[List[int]]`):
|
||||
A list of lists containing the number of tiles
|
||||
for each image in each batch sample.
|
||||
"""
|
||||
|
||||
# Determine output shape
|
||||
batch_size = len(batch_images)
|
||||
max_num_images = max([len(images) for images in batch_images])
|
||||
shapes = [image.shape for images in batch_images for image in images]
|
||||
_, channels, tile_height, tile_width = shapes[0]
|
||||
|
||||
# Initialize the stacked images array with zeros
|
||||
stacked_images = np.zeros(
|
||||
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Fill the stacked images array with the tiled images from the batch
|
||||
all_num_tiles = []
|
||||
for i, images in enumerate(batch_images):
|
||||
num_sample_tiles = []
|
||||
for j, image in enumerate(images):
|
||||
num_tiles = image.shape[0]
|
||||
stacked_images[i, j, :num_tiles] = image
|
||||
num_sample_tiles.append(num_tiles)
|
||||
all_num_tiles.append(num_sample_tiles)
|
||||
|
||||
return stacked_images, all_num_tiles
|
||||
|
||||
|
||||
def pack_aspect_ratios(aspect_ratios: List[List[Tuple[int, int]]], pad_value: int = 1) -> np.ndarray:
|
||||
"""
|
||||
Stack a list of aspect ratios into a numpy array.
|
||||
|
||||
Args:
|
||||
aspect_ratios (`List[List[Tuple[int, int]]]`):
|
||||
A list of aspect ratios.
|
||||
pad_value (`int`, *optional*, defaults to 1):
|
||||
The value to pad the aspect ratios with.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
The aspect ratios stacked into a numpy array with shape (batch_size, max_num_images, 2).
|
||||
"""
|
||||
batch_size = len(aspect_ratios)
|
||||
max_num_images = max([len(row) for row in aspect_ratios])
|
||||
|
||||
aspect_ratios_stacked = np.full((batch_size, max_num_images, 2), pad_value, dtype=np.int64)
|
||||
for i, row in enumerate(aspect_ratios):
|
||||
if len(row) > 0:
|
||||
aspect_ratios_stacked[i, : len(row)] = np.array(row)
|
||||
return aspect_ratios_stacked
|
||||
|
||||
|
||||
def convert_aspect_ratios_to_ids(aspect_ratios: List[List[Tuple[int, int]]], max_image_tiles: int) -> np.ndarray:
|
||||
"""
|
||||
Convert aspect ratio tuples to unique ids.
|
||||
|
||||
For batch padding we use 0, because there might be a different number of images in each batch.
|
||||
The aspect ratio ids start from 1, with 1 corresponding to the first supported aspect ratio.
|
||||
|
||||
Args:
|
||||
aspect_ratios (`List[List[Tuple[int, int]]]`):
|
||||
A list of aspect ratios for each image in the batch.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles any image can be split into.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
The aspect ratios ids as a numpy array with shape (batch_size, max_num_images).
|
||||
Each id corresponds to the index of the aspect ratio in the list of supported aspect ratios,
|
||||
offset by 1 (so 0 can be used for padding).
|
||||
"""
|
||||
|
||||
batch_size = len(aspect_ratios)
|
||||
max_num_images = max([len(row) for row in aspect_ratios])
|
||||
supported_aspect_ratios = get_all_supported_aspect_ratios(max_image_tiles)
|
||||
|
||||
aspect_ratios_ids = np.zeros((batch_size, max_num_images), dtype=np.int64)
|
||||
for i, sample_aspect_ratios in enumerate(aspect_ratios):
|
||||
for j, (num_tiles_h, num_tiles_w) in enumerate(sample_aspect_ratios):
|
||||
aspect_ratios_ids[i, j] = supported_aspect_ratios.index((num_tiles_h, num_tiles_w)) + 1
|
||||
return aspect_ratios_ids
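A small sketch: with `max_image_tiles=4` the supported ratios are `[(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]`, so a 2x2 split maps to id 6 (index 5, offset by 1):

```python
print(convert_aspect_ratios_to_ids([[(2, 2)]], max_image_tiles=4))
# [[6]]
```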
|
||||
|
||||
|
||||
def to_channel_dimension_format(
|
||||
image: np.ndarray,
|
||||
channel_dim: Union[ChannelDimension, str],
|
||||
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`.
|
||||
|
||||
Args:
|
||||
image (`numpy.ndarray`):
|
||||
The image to have its channel dimension set.
|
||||
channel_dim (`ChannelDimension`):
|
||||
The channel dimension format to use.
|
||||
input_channel_dim (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
The image with the channel dimension set to `channel_dim`.
|
||||
"""
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
if input_channel_dim is None:
|
||||
input_channel_dim = infer_channel_dimension_format(image)
|
||||
|
||||
target_channel_dim = ChannelDimension(channel_dim)
|
||||
if input_channel_dim == target_channel_dim:
|
||||
return image
|
||||
|
||||
if target_channel_dim == ChannelDimension.FIRST:
|
||||
image = image.transpose((2, 0, 1))
|
||||
elif target_channel_dim == ChannelDimension.LAST:
|
||||
image = image.transpose((1, 2, 0))
|
||||
else:
|
||||
raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# Copied from transformers.models.idefics2.image_processing_idefics2.convert_to_rgb
|
||||
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||
"""
|
||||
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
||||
as is.
|
||||
Args:
|
||||
image (Image):
|
||||
The image to convert.
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
|
||||
# for transparent images. The call to `alpha_composite` handles this case
|
||||
if image.mode == "RGB":
|
||||
return image
|
||||
|
||||
image_rgba = image.convert("RGBA")
|
||||
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
||||
alpha_composite = Image.alpha_composite(background, image_rgba)
|
||||
alpha_composite = alpha_composite.convert("RGB")
|
||||
return alpha_composite
|
||||
|
||||
|
||||
# Modified from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
|
||||
def make_list_of_images(images: ImageInput) -> List[List[Optional[np.ndarray]]]:
|
||||
"""
|
||||
Convert a single image or a list of images to a list of numpy arrays.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
A single image or a list of images.
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays.
|
||||
"""
|
||||
# If it's a single image, convert it to a list of lists
|
||||
if is_valid_image(images):
|
||||
output_images = [[images]]
|
||||
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
||||
elif isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
|
||||
output_images = [images]
|
||||
# If it's a list of batches, it's already in the right format
|
||||
elif (
|
||||
isinstance(images, (list, tuple))
|
||||
and all(isinstance(images_i, (list, tuple)) for images_i in images)
|
||||
and any(is_valid_list_of_images(images_i) for images_i in images)
|
||||
):
|
||||
output_images = images
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
|
||||
)
|
||||
return output_images
|
||||
|
||||
|
||||
def is_valid_list_of_images(images: List):
|
||||
return images and all(is_valid_image(image) for image in images)
|
||||
|
||||
|
||||
def _validate_size(size: Dict[str, int]) -> None:
|
||||
if not ("height" in size and "width" in size):
|
||||
raise ValueError(f"Argument `size` must be a dictionary with keys 'height' and 'width'. Got: {size}")
|
||||
if size["height"] != size["width"]:
|
||||
raise ValueError(f"Argument `size` must have the same height and width, got {size}")
|
||||
|
||||
|
||||
def _validate_mllama_preprocess_arguments(do_resize, size, do_pad, max_image_tiles):
|
||||
if not do_pad:
|
||||
raise ValueError("MllamaImageProcessor doesn't support `do_pad=False` mode.")
|
||||
if not do_resize:
|
||||
raise ValueError("MllamaImageProcessor doesn't support `do_resize=False` mode.")
|
||||
if max_image_tiles is None or max_image_tiles <= 0:
|
||||
raise ValueError(f"MllamaImageProcessor `max_image_tiles` must be a positive integer, got {max_image_tiles}.")
|
||||
_validate_size(size)
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
Constructs a Mllama image processor.
|
||||
|
||||
Args:
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
|
||||
Only has an effect if the input image is in the PIL format.
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Size of the image tile. Should be a dictionary containing 'height' and 'width' keys, both with integer values.
|
||||
The height and width values should be equal.
|
||||
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `1/255`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image.
|
||||
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to pad the images to the largest height and width in the batch.
|
||||
max_image_tiles (`int`, *optional*, defaults to 4):
|
||||
The maximum number of tiles to split the image into.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "num_tiles", "aspect_ratio_ids", "aspect_ratio_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_convert_rgb: bool = True,
|
||||
do_resize: bool = True,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: bool = True,
|
||||
max_image_tiles: int = 4,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_resize = do_resize
|
||||
self.size = size if size is not None else {"height": 224, "width": 224}
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self.do_pad = do_pad
|
||||
self.max_image_tiles = max_image_tiles
|
||||
|
||||
_validate_mllama_preprocess_arguments(self.do_resize, self.size, self.do_pad, self.max_image_tiles)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
max_image_tiles: Optional[int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
):
|
||||
"""
|
||||
Preprocess a batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
A list of images to preprocess.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image tile. Should be a dictionary containing 'height' and 'width' keys, both with integer values.
|
||||
The height and width values should be equal.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether or not to pad the images to the largest height and width in the batch.
|
||||
max_image_tiles (`int`, *optional*, defaults to `self.max_image_tiles`):
|
||||
The maximum number of tiles to split the image into.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
`BatchFeature` of the following structure:
|
||||
- **pixel_values** (`TensorType`): The preprocessed pixel values.
|
||||
- **aspect_ratio_ids** (`TensorType`): The aspect ratio ids of the images.
|
||||
- **num_tiles** (`List[List[int]]`): The number of tiles for each image in the batch.
|
||||
"""
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
max_image_tiles = max_image_tiles if max_image_tiles is not None else self.max_image_tiles
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
# extra validation
|
||||
_validate_mllama_preprocess_arguments(do_resize, size, do_pad, max_image_tiles)
|
||||
|
||||
images_list = make_list_of_images(images)
|
||||
|
||||
if self.do_convert_rgb:
|
||||
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||
|
||||
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||
|
||||
batch_images = []
|
||||
batch_aspect_ratios = []
|
||||
|
||||
# iterate over batch samples
|
||||
for images in images_list:
|
||||
sample_images = []
|
||||
sample_aspect_ratios = []
|
||||
|
||||
# iterate over images in a batch sample
|
||||
for image in images:
|
||||
# convert images to channels first format for faster processing
|
||||
# LAST is slower for `pad` and not supported by `split_to_tiles`
|
||||
data_format = ChannelDimension.FIRST
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
|
||||
# do_resize=False is not supported, validated
|
||||
image, aspect_ratio = self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
resample=resample,
|
||||
max_image_tiles=max_image_tiles,
|
||||
input_data_format=data_format,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
# do_pad=False is not supported, validated
|
||||
image = self.pad(
|
||||
image=image,
|
||||
size=size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
input_data_format=data_format,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(
|
||||
image=image,
|
||||
scale=rescale_factor,
|
||||
input_data_format=input_data_format,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
num_tiles_height, num_tiles_width = aspect_ratio
|
||||
image = split_to_tiles(image, num_tiles_height, num_tiles_width)
|
||||
|
||||
sample_images.append(image)
|
||||
sample_aspect_ratios.append((num_tiles_height, num_tiles_width))
|
||||
|
||||
batch_images.append(sample_images)
|
||||
batch_aspect_ratios.append(sample_aspect_ratios)
|
||||
|
||||
images, num_tiles = pack_images(batch_images, max_image_tiles)
|
||||
|
||||
aspect_ratio_ids = convert_aspect_ratios_to_ids(batch_aspect_ratios, max_image_tiles=max_image_tiles)
|
||||
aspect_ratio_mask = build_aspect_ratio_mask(batch_aspect_ratios, max_image_tiles=max_image_tiles)
|
||||
|
||||
# images (np.ndarray) with shape (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
|
||||
# aspect_ratio_ids (np.ndarray) with shape (batch_size, max_num_images) - aspect ratio ids for each image, padded to max_num_images with 0
|
||||
# num_tiles (List[List[int]]) with (batch_size, num_images_in_batch) - real number of tiles for each image, not padded
|
||||
# aspect_ratio_mask (np.ndarray) with shape (batch_size, max_num_images, max_image_tiles) - number of tiles for each image, padded to max_num_images with 0
|
||||
encoded_inputs = BatchFeature(
|
||||
data={
|
||||
"pixel_values": images,
|
||||
"aspect_ratio_ids": aspect_ratio_ids,
|
||||
"aspect_ratio_mask": aspect_ratio_mask,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
encoded_inputs["num_tiles"] = num_tiles
|
||||
|
||||
return encoded_inputs
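        # Illustrative output (assuming the default tile size of 224 and max_image_tiles=4):
        # a single input image yields `pixel_values` of shape (1, 1, 4, 3, 224, 224),
        # one `aspect_ratio_ids` / `num_tiles` entry per image, and an
        # `aspect_ratio_mask` marking which of the 4 tile slots hold real tiles.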
|
||||
|
||||
def pad(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
aspect_ratio: Tuple[int, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image to the `size` x `aspect_ratio`. For example, if size is {height: 224, width: 224} and aspect ratio is
|
||||
(1, 2), the image will be padded to 224x448.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Size of the output image.
|
||||
aspect_ratio (`Tuple[int, int]`):
|
||||
The aspect ratio of the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
"""
|
||||
|
||||
_validate_size(size)
|
||||
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
num_tiles_height, num_tiles_width = aspect_ratio
|
||||
padded_height = num_tiles_height * size["height"]
|
||||
padded_width = num_tiles_width * size["width"]
|
||||
pad_size = ((0, padded_height - image_height), (0, padded_width - image_width))
|
||||
|
||||
image = pad(
|
||||
image,
|
||||
pad_size,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=0,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
max_image_tiles: int,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
    ) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||
"""
|
||||
Resizes an image to fit within a tiled canvas while maintaining its aspect ratio.
|
||||
The optimal canvas size is calculated based on the maximum number of tiles and the tile size.
|
||||
|
||||
The function first determines the best tile arrangement for the image, then resizes the image
|
||||
to fit within this canvas. The resized image and the number of tiles along the height and width
|
||||
dimensions are returned.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Size of the output image.
|
||||
max_image_tiles (`int`):
|
||||
The maximum number of tiles to split the image into.
|
||||
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
|
||||
Returns:
|
||||
            `Tuple[np.ndarray, Tuple[int, int]]`: The resized image and a tuple containing the number of tiles
|
||||
along the height and width dimensions.
|
||||
"""
|
||||
|
||||
_validate_size(size)
|
||||
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
tile_size = size["height"]
|
||||
|
||||
canvas_height, canvas_width = get_optimal_tiled_canvas(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
max_image_tiles=max_image_tiles,
|
||||
tile_size=tile_size,
|
||||
)
|
||||
num_tiles_height = canvas_height // tile_size
|
||||
num_tiles_width = canvas_width // tile_size
|
||||
|
||||
new_height, new_width = get_image_size_fit_to_canvas(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
canvas_height=canvas_height,
|
||||
canvas_width=canvas_width,
|
||||
tile_size=tile_size,
|
||||
)
|
||||
|
||||
image = resize(
|
||||
image,
|
||||
(new_height, new_width),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
return image, (num_tiles_height, num_tiles_width)
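        # Illustrative behavior (mirrors the image processor tests): with a 50x50 tile
        # and max_image_tiles=4, an 80x95 (width x height) image is fitted to a 2x2 tile
        # canvas, a 101x50 image to a canvas 3 tiles wide and 1 tile high, and a 20x39
        # image to a single tile.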
|
src/transformers/models/mllama/modeling_mllama.py (new file, 2288 lines; diff suppressed because it is too large)

src/transformers/models/mllama/processing_mllama.py (new file, 358 lines)
@@ -0,0 +1,358 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Mllama.
|
||||
"""
|
||||
|
||||
from statistics import mean
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
try:
|
||||
from typing import Unpack
|
||||
except ImportError:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import (
|
||||
ImagesKwargs,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from ...tokenization_utils_base import (
|
||||
BatchEncoding,
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
)
|
||||
|
||||
# TODO: Can we do it this way, or is it better to include it as "Copied from ..."?
|
||||
from .image_processing_mllama import make_list_of_images
|
||||
|
||||
|
||||
class MllamaImagesKwargs(ImagesKwargs, total=False):
|
||||
max_image_tiles: Optional[int]
|
||||
|
||||
|
||||
class MllamaProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: MllamaImagesKwargs
|
||||
|
||||
_defaults = {
|
||||
"image_kwargs": {
|
||||
"max_image_tiles": 4,
|
||||
},
|
||||
}
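    # Illustrative usage (hypothetical call): `processor(images=..., text=..., images_kwargs={"max_image_tiles": 2})`
    # overrides the default of 4 tiles defined above for that single call.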
|
||||
|
||||
|
||||
def get_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
|
||||
"""
|
||||
Generate a cross-attention token mask for image tokens in the input sequence.
|
||||
|
||||
This function identifies the positions of image tokens in the input sequence and creates
|
||||
a mask that defines which subsequent tokens each image token should attend to.
|
||||
|
||||
Args:
|
||||
input_ids (List[int]): A list of token ids representing the input sequence.
|
||||
image_token_id (int): The id of the token used to represent images in the sequence.
|
||||
|
||||
Returns:
|
||||
List[List[int]]: A list of [start, end] pairs, where each pair represents the range
|
||||
of tokens an image token should attend to.
|
||||
|
||||
Notes:
|
||||
- If no image tokens are present, an empty list is returned.
|
||||
- For a single image token, it attends to all subsequent tokens until the end of the sequence.
|
||||
- For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
|
||||
- Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
|
||||
"""
|
||||
|
||||
image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
|
||||
|
||||
if len(image_token_locations) == 0:
|
||||
return []
|
||||
|
||||
# only one image present, unmask until end of sequence
|
||||
if len(image_token_locations) == 1:
|
||||
return [[image_token_locations[0], -1]]
|
||||
|
||||
vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
|
||||
|
||||
# last image will attend to all subsequent text
|
||||
vision_masks.append([image_token_locations[-1], len(input_ids)])
|
||||
|
||||
# if there are two or more consecutive vision tokens,
|
||||
# they should all attend to all subsequent
|
||||
# text present
|
||||
last_mask_end = vision_masks[-1][1]
|
||||
for vision_mask in vision_masks[::-1]:
|
||||
if vision_mask[0] == vision_mask[1] - 1:
|
||||
vision_mask[1] = last_mask_end
|
||||
last_mask_end = vision_mask[1]
|
||||
|
||||
return vision_masks
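# Illustrative traced examples (128256 stands for an assumed image token id):
#   get_cross_attention_token_mask([128256, 5, 6, 128256, 7], image_token_id=128256)
#   -> [[0, 3], [3, 5]]
#   get_cross_attention_token_mask([128256, 128256, 5, 6], image_token_id=128256)
#   -> [[0, 4], [1, 4]]  (consecutive image tokens share the same end position)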
|
||||
|
||||
|
||||
def convert_sparse_cross_attention_mask_to_dense(
|
||||
cross_attention_token_mask: List[List[List[int]]],
|
||||
num_tiles: List[List[int]],
|
||||
max_num_tiles: int,
|
||||
length: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Convert the cross attention mask indices to a cross attention mask 4D array.
|
||||
|
||||
This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
|
||||
The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
|
||||
|
||||
Args:
|
||||
cross_attention_token_mask (List[List[List[int]]]): A nested list structure where:
|
||||
- The outer list represents the batch dimension.
|
||||
- The middle list represents different images within each batch item.
|
||||
- The inner list contains pairs of integers [start, end] representing token ranges for each image.
|
||||
num_tiles (List[List[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
|
||||
max_num_tiles (int): The maximum possible number of tiles.
|
||||
length (int): The total sequence length of the input.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
|
||||
The array contains `1` where attention is allowed and `0` where it is not.
|
||||
|
||||
Note:
|
||||
- Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
|
||||
"""
|
||||
|
||||
batch_size = len(cross_attention_token_mask)
|
||||
max_num_images = max([len(masks) for masks in cross_attention_token_mask])
|
||||
|
||||
cross_attention_mask = np.zeros(
|
||||
shape=(batch_size, length, max_num_images, max_num_tiles),
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
|
||||
for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
|
||||
if len(locations) == 2:
|
||||
start, end = locations
|
||||
end = min(end, length)
|
||||
if end == -1:
|
||||
end = length
|
||||
cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
|
||||
return cross_attention_mask
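# Illustrative traced example: one sample whose single image attends from token 0 to
# the end of a length-5 sequence and has 2 real tiles out of max_num_tiles=4:
#   convert_sparse_cross_attention_mask_to_dense([[[0, -1]]], num_tiles=[[2]], max_num_tiles=4, length=5)
#   -> array of shape (1, 5, 1, 4) with ones at [:, :, 0, :2] and zeros elsewhere.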
|
||||
|
||||
|
||||
def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
|
||||
"""
|
||||
Builds a string from the input prompt by adding `bos_token` if not already present.
|
||||
|
||||
Args:
|
||||
prompt (`str`):
|
||||
The input prompt string.
|
||||
bos_token (`str`):
|
||||
The beginning of sentence token to be added.
|
||||
image_token (`str`):
|
||||
The image token used to identify the start of an image sequence.
|
||||
|
||||
Returns:
|
||||
str: The modified prompt string with the `bos_token` added if necessary.
|
||||
|
||||
Examples:
|
||||
>>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
|
||||
'<begin_of_text>Hello world'
|
||||
|
||||
>>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
|
||||
'<|image|><begin_of_text>Hello world'
|
||||
|
||||
>>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
|
||||
'<begin_of_text>Hello world'
|
||||
"""
|
||||
|
||||
if bos_token in prompt:
|
||||
return prompt
|
||||
|
||||
num_image_tokens_on_start = 0
|
||||
while prompt.startswith(image_token):
|
||||
prompt = prompt[len(image_token) :]
|
||||
num_image_tokens_on_start += 1
|
||||
|
||||
return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
|
||||
|
||||
|
||||
class MllamaProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Mllama processor which wraps [`MllamaImageProcessor`] and
|
||||
[`PretrainedTokenizerFast`] into a single processor that inherits both the image processor and
|
||||
    tokenizer functionalities. See the [`~MllamaProcessor.__call__`] and [`~MllamaProcessor.decode`] for more
|
||||
information.
|
||||
The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
|
||||
```python
|
||||
from transformers import MllamaProcessor
|
||||
from PIL import Image
|
||||
|
||||
processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
|
||||
|
||||
processor(
|
||||
images=your_pil_image,
|
||||
text=["<|image|>If I had to write a haiku for this one"],
|
||||
images_kwargs = {"size": {"height": 448, "width": 448}},
|
||||
text_kwargs = {"padding": "right"},
|
||||
common_kwargs = {"return_tensors": "pt"},
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
image_processor ([`MllamaImageProcessor`]):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "MllamaImageProcessor"
|
||||
tokenizer_class = "PreTrainedTokenizerFast"
|
||||
|
||||
def __init__(self, image_processor, tokenizer):
|
||||
self.image_token = "<|image|>"
|
||||
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
self.python_token = "<|python_tag|>"
|
||||
self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
|
||||
self.bos_token = tokenizer.bos_token
|
||||
self.chat_template = tokenizer.chat_template
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[ImageInput] = None,
|
||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
**kwargs: Unpack[MllamaProcessorKwargs],
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
|
||||
arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` arguments to
|
||||
MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Please refer
|
||||
to the docstring of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
            - **aspect_ratio_ids** -- Aspect ratio ids of each image. Returned when `images` is not `None`.
            - **aspect_ratio_mask** -- Mask indicating which tile slots of each image are valid. Returned when `images` is not `None`.
            - **cross_attention_mask** -- Mask defining which image tiles each text token may attend to. Returned when both `images` and `text` are not `None`.
|
||||
"""
|
||||
if text is None and images is None:
|
||||
raise ValueError("You must specify either text or images.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MllamaProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
images_kwargs = output_kwargs["images_kwargs"]
|
||||
common_kwargs = output_kwargs["common_kwargs"]
|
||||
|
||||
data = {}
|
||||
if text is not None:
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
n_images_in_text = [t.count(self.image_token) for t in text]
|
||||
text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
|
||||
_ = text_kwargs.pop("padding_side", None) # hack until padding-side is an accepted kwarg by tokenizers
|
||||
encoding = self.tokenizer(text, **text_kwargs)
|
||||
data.update(encoding)
|
||||
|
||||
if images is not None:
|
||||
images = make_list_of_images(images)
|
||||
n_images_in_images = [len(sample) for sample in images]
|
||||
|
||||
if text is not None:
|
||||
if (
|
||||
not all(batch_img_per_prompt == n_images_in_images for batch_img_per_prompt in n_images_in_text)
|
||||
and len(text) > 1
|
||||
):
|
||||
raise ValueError(
|
||||
f"The number of images in each batch {n_images_in_text} should be the same {n_images_in_images} should be the same. Yes, the model does not \
|
||||
support having a different number of images per batch."
|
||||
)
|
||||
if int(mean(n_images_in_text)) != int(mean(n_images_in_images)):
|
||||
raise ValueError(
|
||||
f"The number of images in the text ({n_images_in_text}) should be the same as in the number of provided images ({n_images_in_images}) \
|
||||
should be the same."
|
||||
)
|
||||
|
||||
image_features = self.image_processor(images, **images_kwargs)
|
||||
num_tiles = image_features.pop("num_tiles")
|
||||
data.update(image_features)
|
||||
|
||||
# Create cross attention mask
|
||||
if images is not None and text is not None:
|
||||
cross_attention_token_mask = [
|
||||
get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
|
||||
]
|
||||
cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
|
||||
cross_attention_token_mask,
|
||||
num_tiles=num_tiles,
|
||||
max_num_tiles=self.image_processor.max_image_tiles,
|
||||
length=max(len(input_ids) for input_ids in encoding["input_ids"]),
|
||||
)
|
||||
data["cross_attention_mask"] = cross_attention_mask
|
||||
|
||||
return_tensors = common_kwargs.pop("return_tensors", None)
|
||||
batch_encoding = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return batch_encoding
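        # Illustrative result (hypothetical inputs): `processor(images=pil_image,
        # text="<|image|>Describe this image", return_tensors="pt")` typically returns a
        # BatchFeature with `input_ids`, `attention_mask`, `pixel_values`,
        # `aspect_ratio_ids`, `aspect_ratio_mask` and `cross_attention_mask`.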
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])
|
@@ -5945,6 +5945,48 @@ class MixtralPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaTextModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MllamaVisionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MobileBertForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@@ -408,6 +408,13 @@ class MaskFormerImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class MllamaImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class MobileNetV1FeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
@@ -490,7 +490,7 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
@@ -631,7 +631,7 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
@@ -983,7 +983,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.is_decoder = True
|
||||
|
||||
# test old generation output for backwards compatibility
|
||||
@@ -1014,7 +1014,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1054,7 +1054,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1085,6 +1085,7 @@ class GenerationTesterMixin:
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1172,7 +1173,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1249,7 +1250,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1362,7 +1363,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1549,7 +1550,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
@@ -1745,7 +1746,7 @@ class GenerationTesterMixin:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
# Let's make it always:
|
||||
# 1. use cache (for obvious reasons)
|
||||
@@ -1845,12 +1846,13 @@ class GenerationTesterMixin:
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
set_seed(seed)
|
||||
num_hidden_layers = config.get_text_config().num_hidden_layers
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
|
||||
else:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
past_key_values = cache_cls(num_hidden_layers)
|
||||
new_results = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@@ -1870,23 +1872,27 @@ class GenerationTesterMixin:
|
||||
new_cache_converted = new_results.past_key_values.to_legacy_cache()
|
||||
for layer_idx in range(len(legacy_cache)):
|
||||
for kv_idx in range(len(legacy_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
legacy_cache[layer_idx][kv_idx],
|
||||
new_cache_converted[layer_idx][kv_idx],
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
if legacy_cache[layer_idx][kv_idx] != []:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
legacy_cache[layer_idx][kv_idx],
|
||||
new_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
new_cache = new_results.past_key_values
|
||||
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
||||
for layer_idx in range(len(new_cache)):
|
||||
for kv_idx in range(len(new_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][kv_idx],
|
||||
legacy_cache_converted[layer_idx][kv_idx],
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
if new_cache[layer_idx][kv_idx] != []:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][kv_idx],
|
||||
legacy_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_static_cache(self):
|
||||
@@ -1960,8 +1966,12 @@ class GenerationTesterMixin:
|
||||
|
||||
# passing past key values of different type should raise Error
|
||||
with self.assertRaises(ValueError):
|
||||
num_hidden_layers = config.get_text_config().num_hidden_layers
|
||||
model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_valyes=DynamicCache(), **generation_kwargs
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_valyes=DynamicCache(num_hidden_layers),
|
||||
**generation_kwargs,
|
||||
)
|
||||
|
||||
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
|
||||
@@ -2004,6 +2014,12 @@ class GenerationTesterMixin:
|
||||
"max_new_tokens": 10,
|
||||
}
|
||||
|
||||
max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
|
||||
config = config.get_text_config()
|
||||
past_key_values = StaticCache(
|
||||
config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
|
||||
)
|
||||
|
||||
for model_inputs in input_ids_sets:
|
||||
# eager dynamic cache
|
||||
output_dynamic = model.generate(model_inputs, **generation_kwargs)
|
||||
@@ -2013,7 +2029,9 @@ class GenerationTesterMixin:
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.update(**generation_kwargs)
|
||||
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
||||
output_compiled = compiled_generate(
|
||||
model_inputs, generation_config=generation_config, past_key_values=past_key_values
|
||||
)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
|
tests/models/mllama/__init__.py (new empty file)

tests/models/mllama/test_image_processing_mllama.py (new file, 355 lines)
@@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import MllamaImageProcessor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class MllamaImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
num_images=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
do_convert_rgb=True,
|
||||
do_pad=True,
|
||||
max_image_tiles=4,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.max_image_tiles = max_image_tiles
|
||||
self.image_size = image_size
|
||||
self.num_images = num_images
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_pad = do_pad
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_rescale": self.do_rescale,
|
||||
"rescale_factor": self.rescale_factor,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_pad": self.do_pad,
|
||||
"max_image_tiles": self.max_image_tiles,
|
||||
}
|
||||
|
||||
def prepare_image_inputs(
|
||||
self,
|
||||
batch_size=None,
|
||||
min_resolution=None,
|
||||
max_resolution=None,
|
||||
num_channels=None,
|
||||
num_images=None,
|
||||
size_divisor=None,
|
||||
equal_resolution=False,
|
||||
numpify=False,
|
||||
torchify=False,
|
||||
):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
|
||||
One can specify whether the images are of the same resolution or not.
|
||||
"""
|
||||
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
|
||||
|
||||
batch_size = batch_size if batch_size is not None else self.batch_size
|
||||
min_resolution = min_resolution if min_resolution is not None else self.min_resolution
|
||||
max_resolution = max_resolution if max_resolution is not None else self.max_resolution
|
||||
num_channels = num_channels if num_channels is not None else self.num_channels
|
||||
num_images = num_images if num_images is not None else self.num_images
|
||||
|
||||
images_list = []
|
||||
for i in range(batch_size):
|
||||
images = []
|
||||
for j in range(num_images):
|
||||
if equal_resolution:
|
||||
width = height = max_resolution
|
||||
else:
|
||||
# To avoid getting image width/height 0
|
||||
if size_divisor is not None:
|
||||
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
|
||||
min_resolution = max(size_divisor, min_resolution)
|
||||
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
|
||||
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
|
||||
images_list.append(images)
|
||||
|
||||
if not numpify and not torchify:
|
||||
# PIL expects the channel dimension as last dimension
|
||||
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
|
||||
|
||||
if torchify:
|
||||
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
|
||||
|
||||
return images_list
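        # Illustrative default output: 7 batch samples of 18 PIL images each, with every
        # image side sampled between 30 and 400 pixels.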
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
expected_output_image_shape = (
|
||||
max(len(images) for images in images),
|
||||
self.max_image_tiles,
|
||||
self.num_channels,
|
||||
self.size["height"],
|
||||
self.size["width"],
|
||||
)
|
||||
return expected_output_image_shape
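        # Illustrative value: for a batch like [[image], [image, image]] and the default
        # settings this returns (2, 4, 3, 224, 224); the tests prepend the batch dimension.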
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class MllamaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = MllamaImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = MllamaImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||
self.assertTrue(hasattr(image_processing, "max_image_tiles"))
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
for sample_images in image_inputs:
|
||||
for image in sample_images:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
expected_output_image_shape = (
|
||||
max(len(images) for images in image_inputs),
|
||||
self.image_processor_tester.max_image_tiles,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.size["height"],
|
||||
self.image_processor_tester.size["width"],
|
||||
)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
||||
for images in image_inputs:
|
||||
for image in images:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
|
||||
for images in image_inputs:
|
||||
for image in images:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape),
|
||||
(self.image_processor_tester.batch_size, *expected_output_image_shape),
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
self.skipTest("4 channels input is not supported yet")
|
||||
|
||||
def test_image_correctly_tiled(self):
|
||||
def get_empty_tiles(pixel_values):
|
||||
# image has shape batch_size, max_num_images, max_image_tiles, num_channels, height, width
|
||||
# we want to get a binary mask of shape batch_size, max_num_images, max_image_tiles
|
||||
# of empty tiles, i.e. tiles that are completely zero
|
||||
return np.all(pixel_values == 0, axis=(3, 4, 5))
|
||||
|
||||
image_processor_dict = {**self.image_processor_dict, "size": {"height": 50, "width": 50}, "max_image_tiles": 4}
|
||||
image_processor = self.image_processing_class(**image_processor_dict)
|
||||
|
||||
# image fits 2x2 tiles grid (width x height)
|
||||
image = Image.new("RGB", (80, 95))
|
||||
inputs = image_processor(image, return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values)[0, 0].tolist()
|
||||
self.assertEqual(empty_tiles, [False, False, False, False])
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids[0, 0]
|
||||
self.assertEqual(aspect_ratio_ids, 6)
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask[0, 0].tolist()
|
||||
self.assertEqual(aspect_ratio_mask, [1, 1, 1, 1])
|
||||
|
||||
# image fits 3x1 grid (width x height)
|
||||
image = Image.new("RGB", (101, 50))
|
||||
inputs = image_processor(image, return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values)[0, 0].tolist()
|
||||
self.assertEqual(empty_tiles, [False, False, False, True])
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids[0, 0]
|
||||
self.assertEqual(aspect_ratio_ids, 3)
|
||||
num_tiles = inputs.aspect_ratio_mask[0, 0].sum()
|
||||
self.assertEqual(num_tiles, 3)
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask[0, 0].tolist()
|
||||
self.assertEqual(aspect_ratio_mask, [1, 1, 1, 0])
|
||||
|
||||
# image fits 1x1 grid (width x height)
|
||||
image = Image.new("RGB", (20, 39))
|
||||
inputs = image_processor(image, return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values)[0, 0].tolist()
|
||||
self.assertEqual(empty_tiles, [False, True, True, True])
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids[0, 0]
|
||||
self.assertEqual(aspect_ratio_ids, 1)
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask[0, 0].tolist()
|
||||
self.assertEqual(aspect_ratio_mask, [1, 0, 0, 0])
|
||||
|
||||
# image fits 2x1 grid (width x height)
|
||||
image = Image.new("RGB", (51, 20))
|
||||
inputs = image_processor(image, return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values)[0, 0].tolist()
|
||||
self.assertEqual(empty_tiles, [False, False, True, True])
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids[0, 0]
|
||||
self.assertEqual(aspect_ratio_ids, 2)
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask[0, 0].tolist()
|
||||
self.assertEqual(aspect_ratio_mask, [1, 1, 0, 0])
|
||||
|
||||
# image is greater than 2x2 tiles grid (width x height)
|
||||
image = Image.new("RGB", (150, 150))
|
||||
inputs = image_processor(image, return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values)[0, 0].tolist()
|
||||
self.assertEqual(empty_tiles, [False, False, False, False])
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids[0, 0]
|
||||
self.assertEqual(aspect_ratio_ids, 6) # (2 - 1) * 4 + 2 = 6
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask[0, 0].tolist()
|
||||
self.assertEqual(aspect_ratio_mask, [1, 1, 1, 1])
|
||||
|
||||
# batch of images
|
||||
image1 = Image.new("RGB", (80, 95))
|
||||
image2 = Image.new("RGB", (101, 50))
|
||||
image3 = Image.new("RGB", (23, 49))
|
||||
inputs = image_processor([[image1], [image2, image3]], return_tensors="np")
|
||||
pixel_values = inputs.pixel_values
|
||||
empty_tiles = get_empty_tiles(pixel_values).tolist()
|
||||
expected_empty_tiles = [
|
||||
# sample 1 with 1 image 2x2 grid
|
||||
[
|
||||
[False, False, False, False],
|
||||
[True, True, True, True], # padding
|
||||
],
|
||||
# sample 2
|
||||
[
|
||||
[False, False, False, True], # 3x1
|
||||
[False, True, True, True], # 1x1
|
||||
],
|
||||
]
|
||||
self.assertEqual(empty_tiles, expected_empty_tiles)
|
||||
aspect_ratio_ids = inputs.aspect_ratio_ids.tolist()
|
||||
expected_aspect_ratio_ids = [[6, 0], [3, 1]]
|
||||
self.assertEqual(aspect_ratio_ids, expected_aspect_ratio_ids)
|
||||
aspect_ratio_mask = inputs.aspect_ratio_mask.tolist()
|
||||
expected_aspect_ratio_mask = [
|
||||
[
|
||||
[1, 1, 1, 1],
|
||||
[1, 0, 0, 0],
|
||||
],
|
||||
[
|
||||
[1, 1, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
],
|
||||
]
|
||||
self.assertEqual(aspect_ratio_mask, expected_aspect_ratio_mask)
|
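# A minimal, self-contained sketch of the conventions these assertions exercise
# (illustrative assumptions, with hypothetical helper names -- not code from this file):
# a padding tile is an all-zero tile, and `aspect_ratio_ids` behaves like the 1-based index
# of the image's tile grid in the canonical list of supported grids for max_image_tiles=4,
# while `aspect_ratio_mask` simply marks the real (non-padding) tiles.
import numpy as np

def empty_tiles_sketch(pixel_values):
    # True where a tile contains only zeros, i.e. it was added as padding
    return np.all(pixel_values == 0, axis=(-1, -2, -3))

SUPPORTED_GRIDS = [(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]

def grid_to_aspect_ratio_id(grid):
    # 1-based position of the tile grid in the canonical list of supported grids
    return SUPPORTED_GRIDS.index(grid) + 1

assert grid_to_aspect_ratio_id((2, 2)) == 6  # 2x2 grid -> id 6, matching the assertions above
assert grid_to_aspect_ratio_id((1, 3)) == 3  # 3-tile strip -> id 3
assert grid_to_aspect_ratio_id((1, 2)) == 2  # 2-tile strip -> id 2
assert grid_to_aspect_ratio_id((1, 1)) == 1  # single tile -> id 1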
tests/models/mllama/test_modeling_mllama.py (new file, 642 lines)
@ -0,0 +1,642 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Mllama model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
BitsAndBytesConfig,
|
||||
MllamaConfig,
|
||||
MllamaForCausalLM,
|
||||
MllamaForConditionalGeneration,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_bitsandbytes,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class MllamaText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
ignore_index=-100,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
text_config={
|
||||
"model_type": "mllama",
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"hidden_act": "gelu",
|
||||
"max_position_embeddings": 512,
|
||||
"initializer_range": 0.02,
|
||||
"rope_scaling": {"rope_type": "default"},
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
},
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.text_config = text_config
|
||||
self.seq_length = seq_length
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.vocab_size = text_config["vocab_size"]
|
||||
self.hidden_size = text_config["hidden_size"]
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
self.is_training = is_training
|
||||
self.pad_token_id = self.text_config["pad_token_id"]
|
||||
self.batch_size = 3
|
||||
|
||||
def get_config(self):
|
||||
return MllamaTextConfig(**self.text_config)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.vocab_size - 1) + 1
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_ids, attention_mask = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_mllama_model_fp16_forward(self, config, input_ids, attention_mask):
|
||||
model = MllamaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `MllamaForCausalLM`.
|
||||
"""
|
||||
|
||||
all_model_classes = (MllamaForCausalLM,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MllamaForCausalLM,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_torch_compile_test_ckpt = "nltpt/Llama-3.2-11B-Vision"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MllamaText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MllamaTextConfig, has_text_modality=True)
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
super().test_eager_matches_sdpa_generate()
|
||||
|
||||
@unittest.skip(reason="The outputs don't match, no idea why")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Quanto test is borken")
|
||||
def test_generate_with_quant_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
class MllamaVisionText2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
ignore_index=-100,
|
||||
image_token_index=4,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
text_config={
|
||||
"model_type": "mllama",
|
||||
"vocab_size": 99,
|
||||
"hidden_size": 32,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"hidden_act": "gelu",
|
||||
"max_position_embeddings": 512,
|
||||
"initializer_range": 0.02,
|
||||
"rope_scaling": {"rope_type": "default"},
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"cross_attention_layers": [1],
|
||||
},
|
||||
vision_config={
|
||||
"image_size": 30,
|
||||
"patch_size": 2,
|
||||
"num_channels": 3,
|
||||
"hidden_size": 16,
|
||||
"intermediate_layers_indices": [0],
|
||||
"vision_output_dim": 32,
|
||||
"projection_dim": 32,
|
||||
"num_hidden_layers": 6,
|
||||
"num_global_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"dropout": 0.1,
|
||||
"initializer_range": 0.02,
|
||||
"supported_aspect_ratios": [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1], [2, 2], [3, 1], [4, 1]],
|
||||
},
|
||||
):
|
||||
self.parent = parent
|
||||
self.is_training = is_training
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.seq_length = seq_length
|
||||
|
||||
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||
self.vocab_size = text_config["vocab_size"]
|
||||
self.hidden_size = text_config["hidden_size"]
|
||||
self.num_attention_heads = text_config["num_attention_heads"]
|
||||
self.pad_token_id = self.text_config["pad_token_id"]
|
||||
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 224
|
||||
self.max_num_images = 1
|
||||
self.max_image_tiles = 4
|
||||
|
||||
def get_config(self):
|
||||
return MllamaConfig(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
image_token_index=self.image_token_index,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size,
|
||||
self.max_num_images,
|
||||
self.max_image_tiles,
|
||||
self.vision_config["num_channels"],
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["image_size"],
|
||||
]
|
||||
)
|
||||
aspect_ratio_ids = torch.tensor([[6] * self.batch_size], device=torch_device).transpose(0, 1)
|
||||
aspect_ratio_mask = torch.ones(self.batch_size, self.max_num_images, self.max_image_tiles)
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, aspect_ratio_ids, aspect_ratio_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, aspect_ratio_ids, aspect_ratio_mask = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
aspect_ratio_mask = aspect_ratio_mask.to(torch_device)
|
||||
cross_attention_mask = torch.ones(
|
||||
(self.batch_size, self.seq_length, self.max_num_images, self.max_image_tiles), device=torch_device
|
||||
)
|
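# Shapes of the multimodal inputs assembled in this tester:
#   pixel_values:          (batch_size, max_num_images, max_image_tiles, num_channels, image_size, image_size)
#   aspect_ratio_ids:      (batch_size, max_num_images)
#   aspect_ratio_mask:     (batch_size, max_num_images, max_image_tiles)
#   cross_attention_mask:  (batch_size, seq_length, max_num_images, max_image_tiles)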
||||
|
||||
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||
input_ids[:, 1] = config.image_token_index
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"aspect_ratio_ids": aspect_ratio_ids,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"aspect_ratio_mask": aspect_ratio_mask,
|
||||
"cross_attention_mask": cross_attention_mask,
|
||||
"use_cache": True,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_mllama_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
|
||||
model = MllamaForConditionalGeneration(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values.to(torch.bfloat16),
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `MllamaForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MllamaVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MllamaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
super().test_eager_matches_sdpa_generate()
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference_1_bfloat16(self):
|
||||
# A workaround to override parametrized test with flaky decorator
|
||||
super().test_eager_matches_sdpa_inference_1_bfloat16()
|
||||
|
||||
@unittest.skip(reason="Static cache not supported")
|
||||
def test_static_cache_matches_dynamic(self):
|
||||
# TypeError: list indices must be integers or slices, not tuple
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
|
||||
def test_generate_compile_fullgraph(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The outputs don't match, no idea why")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama is not yet supported by compile")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
# TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
|
||||
# relevant issue: https://github.com/pytorch/pytorch/issues/133166
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The test itself is broken") # TODO @zucchini-nlp
|
||||
def test_generate_with_quant_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_compile_cuda_graph_time(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_torch_compile_fullgraph(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Device side assert triggered")
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_new_cache_format_2(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.base_model_checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||
self.instruct_model_checkpoint = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_11b_model_integration_generate(self):
|
||||
# Prepare inputs
|
||||
processor = AutoProcessor.from_pretrained(self.base_model_checkpoint)
|
||||
|
||||
prompt = "<|image|>If I had to write a haiku for this one"
|
||||
url = "https://llava-vl.github.io/static/images/view.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# Check inputs ids
|
||||
expected_input_ids = torch.tensor([[128256, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342, 369, 420, 832]], device=torch_device) # fmt: skip
|
||||
self.assertTrue(torch.equal(inputs["input_ids"], expected_input_ids))
|
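# In the ids above, 128256 is the <|image|> placeholder inserted for the image and
# 128000 is <|begin_of_text|> (the same special-token ids appear in the processor test below).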
||||
|
||||
# Load model in 4 bit
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.base_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_11b_model_integration_generate_text_only(self):
|
||||
# Prepare inputs
|
||||
processor = AutoProcessor.from_pretrained(self.base_model_checkpoint)
|
||||
prompt = "If I had to write a haiku"
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(torch_device)
|
||||
|
||||
# Check inputs ids
|
||||
expected_input_ids = [128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342]
|
||||
self.assertEqual(inputs["input_ids"].cpu().squeeze().tolist(), expected_input_ids)
|
||||
|
||||
# Load model in 4 bit
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.base_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "If I had to write a haiku about my life, I think it would be something like:\n\"Life is a messy stream\nTwists and turns, ups" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_11b_model_integration_forward(self):
|
||||
# Prepare inputs
|
||||
processor = AutoProcessor.from_pretrained(self.base_model_checkpoint)
|
||||
|
||||
prompt = "<|image|>If I had to write a haiku for this one"
|
||||
url = "https://llava-vl.github.io/static/images/view.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# Load model in 4 bit
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.base_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
# Forward
|
||||
with torch.inference_mode():
|
||||
output = model(**inputs)
|
||||
|
||||
actual_logits = output.logits[0, -1, :5].cpu()
|
||||
expected_logits = torch.tensor([8.3594, 7.7148, 4.7266, 0.7803, 3.1504])
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_logits, expected_logits, atol=0.1),
|
||||
f"Actual logits: {actual_logits}"
|
||||
f"\nExpected logits: {expected_logits}"
|
||||
f"\nDifference: {torch.abs(actual_logits - expected_logits)}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_11b_model_integration_batched_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.base_model_checkpoint)
|
||||
|
||||
# Prepare inputs
|
||||
prompt = [
|
||||
"<|image|>If I had to write a haiku for this one",
|
||||
"<|image|>This image shows",
|
||||
]
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=[[image1], [image2]], padding=True, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# Load model in 4 bit
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.base_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_output = "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
# Check second output
|
||||
decoded_output = processor.decode(output[1], skip_special_tokens=True)
|
||||
expected_output = "This image shows is a photograph of a stop sign in front of a Chinese archway. The stop sign is red with white letters and is" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_11b_model_integration_multi_image_generate(self):
|
||||
processor = AutoProcessor.from_pretrained(self.instruct_model_checkpoint)
|
||||
|
||||
# Prepare inputs
|
||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What’s shown in this image?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "This image shows a long wooden dock extending out into a lake."}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What about this one, what do you see here? Can you describe in detail?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
inputs = processor(text=prompt, images=[[image1, image2]], return_tensors="pt").to(torch_device)
|
||||
prompt_len = inputs["input_ids"].shape[-1]
|
||||
|
||||
# Load model in 4 bit
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = MllamaForConditionalGeneration.from_pretrained(
|
||||
self.instruct_model_checkpoint, quantization_config=quantization_config
|
||||
)
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
|
||||
|
||||
# Check first output
|
||||
generated_output = output[0][prompt_len:]
|
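# The first `prompt_len` tokens of `output[0]` are the prompt itself, so only the newly
# generated continuation is decoded and compared below.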
||||
decoded_output = processor.decode(generated_output, skip_special_tokens=False)
|
||||
|
||||
# the model should respond about the "stop sign", but instead it responds about the "dock"
|
||||
# this happens only in the quantized version; bfloat16 works fine
|
||||
expected_output = "This image shows a long wooden dock extending out into a lake. The dock is made of wooden planks and has a railing"
|
||||
|
||||
self.assertEqual(
|
||||
decoded_output,
|
||||
expected_output,
|
||||
f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
|
||||
)
|
tests/models/mllama/test_processor_mllama.py (new file, 179 lines)
@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MllamaProcessor
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class MllamaProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.checkpoint = "hf-internal-testing/mllama-11b" # TODO: change
|
||||
self.processor = MllamaProcessor.from_pretrained(self.checkpoint)
|
||||
self.image1 = Image.new("RGB", (224, 220))
|
||||
self.image2 = Image.new("RGB", (512, 128))
|
||||
self.image_token = self.processor.image_token
|
||||
self.image_token_id = self.processor.image_token_id
|
||||
self.pad_token_id = self.processor.tokenizer.pad_token_id
|
||||
self.bos_token = self.processor.bos_token
|
||||
self.bos_token_id = self.processor.tokenizer.bos_token_id
|
||||
|
||||
def test_apply_chat_template(self):
|
||||
# Messages whose content is a list mixing image placeholder entries and text entries
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What do these images show?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The first image shows the statue of Liberty in New York."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "And who is that?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
|
||||
expected_rendered = (
|
||||
"<|begin_of_text|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
"<|image|><|image|>What do these images show?"
|
||||
"<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
"The first image shows the statue of Liberty in New York."
|
||||
"<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
"And who is that?"
|
||||
"<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
self.assertEqual(rendered, expected_rendered)
|
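# The template renders each message as <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>,
# inlines an <|image|> placeholder for every {"type": "image"} entry, and appends a trailing
# assistant header because add_generation_prompt=True.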
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "This is a test sentence."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "This is a response."},
|
||||
],
|
||||
},
|
||||
]
|
||||
input_ids = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_ids = [
|
||||
128000, # <|begin_of_text|>
|
||||
128006, # <|start_header_id|>
|
||||
9125, # "system"
|
||||
128007, # <|end_header_id|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
1296,
|
||||
11914,
|
||||
13, # "This is a test sentence."
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
882, # "user"
|
||||
128007, # <|end_header_id|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
2077,
|
||||
13, # "This is a response.",
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
78191, # "assistant"
|
||||
128007, # <|end_header_id|>
|
||||
271, # "\n\n"
|
||||
]
|
||||
|
||||
self.assertEqual(input_ids, expected_ids)
|
||||
|
||||
# test image in multiple locations
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image in two sentences"},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": " Test sentence "},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "ok\n"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
expected_rendered = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
|
||||
"Describe this image in two sentences<|image|> Test sentence <|image|>ok\n<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
self.assertEqual(rendered, expected_rendered)
|
||||
|
||||
input_ids = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
# fmt: off
|
||||
expected_ids = [
|
||||
128000, 128006, 882, 128007, 271, 75885, 420, 2217, 304, 1403, 23719, 128256,
|
||||
3475, 11914, 262, 128256, 564, 198, 128009, 128006, 78191, 128007, 271,
|
||||
]
|
||||
# fmt: on
|
||||
self.assertEqual(input_ids, expected_ids)
|
||||
|
||||
# content given as a plain string instead of a list of typed entries
|
||||
messages_list = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Describe this image in two sentences"},
|
||||
],
|
||||
}
|
||||
]
|
||||
messages_str = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image|>Describe this image in two sentences",
|
||||
}
|
||||
]
|
||||
|
||||
rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
|
||||
rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(rendered_list, rendered_str)
|
@ -446,7 +446,7 @@ class ModelTesterMixin:
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
@ -580,7 +580,7 @@ class ModelTesterMixin:
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
@ -636,7 +636,7 @@ class ModelTesterMixin:
|
||||
def test_torch_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
@ -821,15 +821,16 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="ModelTester is not configured to run training tests")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if (
|
||||
model_class.__name__
|
||||
in [
|
||||
*get_values(MODEL_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
|
||||
]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
continue
|
||||
with self.subTest(model_class.__name__):
|
||||
if (
|
||||
model_class.__name__
|
||||
in [
|
||||
*get_values(MODEL_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
|
||||
]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
@ -4081,6 +4082,7 @@ class ModelTesterMixin:
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
torch.compiler.reset()
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
@ -4127,6 +4129,7 @@ class ModelTesterMixin:
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
torch.compiler.reset()
|
||||
if "cuda" in torch_device:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
@ -4721,7 +4724,6 @@ class ModelTesterMixin:
|
||||
self.skipTest(
|
||||
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
|
||||
)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest(f"{model_class.__name__} does not support static cache")
|
||||
@ -4756,7 +4758,7 @@ class ModelTesterMixin:
|
||||
def test_torch_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
torch.compiler.reset()
|
||||
if not hasattr(self, "_torch_compile_test_ckpt"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
|
||||
ckpt = self._torch_compile_test_ckpt
|
||||
@ -4772,7 +4774,7 @@ class ModelTesterMixin:
|
||||
model.generation_config.max_new_tokens = 4
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead")
|
||||
|
||||
input_text = "Why dogs are cute?"
|
||||
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
|
||||
|
@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
|
||||
def test_dynamic_cache_retrocompatibility(self):
|
||||
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
new_cache = DynamicCache(num_hidden_layers=10)
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
@ -83,7 +83,7 @@ class CacheTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Test 1: We can convert from legacy to new with no changes
|
||||
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
|
||||
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
|
||||
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
new_cache = DynamicCache(num_hidden_layers=10)
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
@ -240,7 +240,9 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
set_seed(0)
|
||||
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
|
||||
)
|
||||
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
|
||||
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
@ -268,7 +270,9 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
model.device
|
||||
)
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
@ -132,6 +132,13 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
"t2u_variance_predictor_hidden_dim",
|
||||
"t2u_variance_predictor_kernel_size",
|
||||
],
|
||||
"MllamaTextConfig": [
|
||||
"initializer_range",
|
||||
],
|
||||
"MllamaVisionConfig": [
|
||||
"initializer_range",
|
||||
"supported_aspect_ratios",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
@ -132,6 +132,8 @@ IGNORE_NON_TESTED = (
|
||||
"SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model.
|
||||
"ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model
|
||||
"Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration.
|
||||
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
|
||||
]
|
||||
)
|
||||
|
||||
|