mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
SmolVLM2 (#36126)
* smolvlm init * updates * fixing bugs * minimal run, no checks * minimal run, no checks * passing first check + adding url support * updating video dataloading logic * fixing image logic * trying modular, but fails * modular is working, changing processor to match PR comments and general transformers logic * fixing kwargs * offloading video loading logic to image_util * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * fixing circleci code formatting errors * update * add idefics3-based tests * add keyword to all * add PreTrainedModel * updateing video loading logic * working inference * updates for PR comments * updates for PR comments * moving SmolVLMPretrainedModel higher to fix import error * CI test pass * CI test pass * removing lambda * CI test pass * CI test pass * CI test pass * CI test pass * CI test pass * CI test pass * processor tests * add example in docs * typo * fix copies * skip compile tests - sdpa for VisionTransformer * fix init * raise import error for num2words * update doc for FA2 * more doc fix * CI * updates for PR comments * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Joshua Lochner <admin@xenova.com> * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * fixing processor -- tokenizer not defined properly, (gpt2 tokenizer), and does not have the attributes of fake image token, etc * adding smolvlm to VQA models * removing vqa auto class * Update src/transformers/models/smolvlm/processing_smolvlm.py Co-authored-by: Joshua Lochner <admin@xenova.com> * removing smolvlmvisiontransformer from index.md * my bad, video processing had typos * fixing docs * renaming params in SmolVLMModel.inputs_merger * removing un-needed dtype/device in model forward * ruff for CI * update docs * Update docs/source/en/model_doc/smolvlm.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * return cache position * return cache position * return cache also in modular * needed to run modular again * fix training tests * push vectorized inputs merger * format * format * reduce number of mappings * addressing PR comments * happy CI, happy me :) * skip non-nested images * adjust integration test for smaller GPUs * format * fix kwargs in chat template apply * skip this for now --------- Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Joshua Lochner <admin@xenova.com>
This commit is contained in:
parent
f2ab182dca
commit
4397dfcb71
@ -317,6 +317,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| [SEW](model_doc/sew) | ✅ | ❌ | ❌ |
|
| [SEW](model_doc/sew) | ✅ | ❌ | ❌ |
|
||||||
| [SEW-D](model_doc/sew-d) | ✅ | ❌ | ❌ |
|
| [SEW-D](model_doc/sew-d) | ✅ | ❌ | ❌ |
|
||||||
| [SigLIP](model_doc/siglip) | ✅ | ❌ | ❌ |
|
| [SigLIP](model_doc/siglip) | ✅ | ❌ | ❌ |
|
||||||
|
| [SmolVLM](model_doc/smolvlm) | ✅ | ❌ | ❌ |
|
||||||
| [Speech Encoder decoder](model_doc/speech-encoder-decoder) | ✅ | ❌ | ✅ |
|
| [Speech Encoder decoder](model_doc/speech-encoder-decoder) | ✅ | ❌ | ✅ |
|
||||||
| [Speech2Text](model_doc/speech_to_text) | ✅ | ✅ | ❌ |
|
| [Speech2Text](model_doc/speech_to_text) | ✅ | ✅ | ❌ |
|
||||||
| [SpeechT5](model_doc/speecht5) | ✅ | ❌ | ❌ |
|
| [SpeechT5](model_doc/speecht5) | ✅ | ❌ | ❌ |
|
||||||
|
197
docs/source/en/model_doc/smolvlm.md
Normal file
197
docs/source/en/model_doc/smolvlm.md
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# SmolVLM
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
SmolVLM2 is an adaptation of the Idefics3 model with two main differences:
|
||||||
|
|
||||||
|
- It uses SmolLM2 for the text model.
|
||||||
|
- It supports multi-image and video inputs
|
||||||
|
|
||||||
|
## Usage tips
|
||||||
|
|
||||||
|
Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: do_resize and size.
|
||||||
|
|
||||||
|
Videos should not be upsampled.
|
||||||
|
|
||||||
|
If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*512 pixels by default.
|
||||||
|
The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 512}` is the default, but you can change it to a different value if needed.
|
||||||
|
|
||||||
|
Here’s how to control resizing and set a custom size:
|
||||||
|
```python
|
||||||
|
image_processor = SmolVLMImageProcessor(do_resize=True, size={"longest_edge": 2 * 512}, max_image_size=512)
|
||||||
|
```
|
||||||
|
|
||||||
|
Additionally, the `max_image_size` parameter, which controls the size of each square patch the image is decomposed into, is set to 512 by default but can be adjusted as needed. After resizing (if applicable), the image processor decomposes the images into square patches based on the `max_image_size` parameter.
|
||||||
|
|
||||||
|
This model was contributed by [orrzohar](https://huggingface.co/orrzohar).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Usage example
|
||||||
|
|
||||||
|
### Single Media inference
|
||||||
|
|
||||||
|
The model can accept both images and videos as input, but you should use only one of the modalities at a time. Here's an example code for that.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
|
||||||
|
model = AutoModelForImageTextToText.from_pretrained(
|
||||||
|
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content":[
|
||||||
|
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
conversation,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||||
|
generated_texts = processor.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
print(generated_texts)
|
||||||
|
|
||||||
|
|
||||||
|
# Video
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "video", "path": "/path/to/video.mp4"},
|
||||||
|
{"type": "text", "text": "Describe this video in detail"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
conversation,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=100)
|
||||||
|
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
print(generated_texts[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Batch Mixed Media Inference
|
||||||
|
|
||||||
|
The model can batch inputs composed of several images/videos and text. Here is an example.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
|
||||||
|
model = AutoModelForImageTextToText.from_pretrained(
|
||||||
|
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conversation for the first image
|
||||||
|
conversation1 = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "path": "/path/to/image.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Conversation with two images
|
||||||
|
conversation2 = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "path": "/path/to/image.jpg"},
|
||||||
|
{"type": "image", "path": "/path/to/image.jpg"},
|
||||||
|
{"type": "text", "text": "What is written in the pictures?"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Conversation with pure text
|
||||||
|
conversation3 = [
|
||||||
|
{"role": "user","content": "who are you?"}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
conversations = [conversation1, conversation2, conversation3]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
conversation,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=100)
|
||||||
|
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
print(generated_texts[0])
|
||||||
|
```
|
||||||
|
|
||||||
|
## SmolVLMConfig
|
||||||
|
|
||||||
|
[[autodoc]] SmolVLMConfig
|
||||||
|
|
||||||
|
## SmolVLMVisionConfig
|
||||||
|
|
||||||
|
[[autodoc]] SmolVLMVisionConfig
|
||||||
|
|
||||||
|
## Idefics3VisionTransformer
|
||||||
|
|
||||||
|
[[autodoc]] SmolVLMVisionTransformer
|
||||||
|
|
||||||
|
## SmolVLMModel
|
||||||
|
|
||||||
|
[[autodoc]] SmolVLMModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## SmolVLMForConditionalGeneration
|
||||||
|
|
||||||
|
[[autodoc]] SmolVLMForConditionalGeneration
|
||||||
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## SmolVLMImageProcessor
|
||||||
|
[[autodoc]] SmolVLMImageProcessor
|
||||||
|
- preprocess
|
||||||
|
|
||||||
|
|
||||||
|
## SmolVLMProcessor
|
||||||
|
[[autodoc]] SmolVLMProcessor
|
||||||
|
- __call__
|
@ -95,6 +95,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
|||||||
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
|
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
|
||||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||||
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
||||||
|
* [SmolVLM](https://huggingface.co/docs/transformers/model_doc/smolvlm#transformers.SmolVLMModel)
|
||||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||||
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
|
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
|
||||||
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
||||||
@ -301,6 +302,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
|||||||
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
|
||||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||||
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
||||||
|
* [SmolVLM](https://huggingface.co/docs/transformers/model_doc/smolvlm#transformers.SmolVLMModel)
|
||||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||||
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
|
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
|
||||||
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import torch
|
|
||||||
from transformers import pipeline, AutoTokenizer, AutoModel, AutoModelForMaskedLM
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoModel, AutoTokenizer, pipeline
|
||||||
|
|
||||||
|
|
||||||
test_sentence = 'Do you [MASK] the muffin man?'
|
test_sentence = 'Do you [MASK] the muffin man?'
|
||||||
|
|
||||||
# for comparison
|
# for comparison
|
||||||
|
@ -776,6 +776,7 @@ _import_structure = {
|
|||||||
"SiglipTextConfig",
|
"SiglipTextConfig",
|
||||||
"SiglipVisionConfig",
|
"SiglipVisionConfig",
|
||||||
],
|
],
|
||||||
|
"models.smolvlm": ["SmolVLMConfig"],
|
||||||
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
|
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
|
||||||
"models.speech_to_text": [
|
"models.speech_to_text": [
|
||||||
"Speech2TextConfig",
|
"Speech2TextConfig",
|
||||||
@ -1288,6 +1289,7 @@ else:
|
|||||||
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
|
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
|
||||||
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
|
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
|
||||||
_import_structure["models.siglip"].append("SiglipImageProcessor")
|
_import_structure["models.siglip"].append("SiglipImageProcessor")
|
||||||
|
_import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"])
|
||||||
_import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
|
_import_structure["models.superglue"].extend(["SuperGlueImageProcessor"])
|
||||||
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
|
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
|
||||||
_import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
|
_import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
|
||||||
@ -3557,6 +3559,16 @@ else:
|
|||||||
"SiglipVisionModel",
|
"SiglipVisionModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.smolvlm"].extend(
|
||||||
|
[
|
||||||
|
"SmolVLMForConditionalGeneration",
|
||||||
|
"SmolVLMModel",
|
||||||
|
"SmolVLMPreTrainedModel",
|
||||||
|
"SmolVLMProcessor",
|
||||||
|
"SmolVLMVisionConfig",
|
||||||
|
"SmolVLMVisionTransformer",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.speech_encoder_decoder"].extend(["SpeechEncoderDecoderModel"])
|
_import_structure["models.speech_encoder_decoder"].extend(["SpeechEncoderDecoderModel"])
|
||||||
_import_structure["models.speech_to_text"].extend(
|
_import_structure["models.speech_to_text"].extend(
|
||||||
[
|
[
|
||||||
@ -5930,6 +5942,7 @@ if TYPE_CHECKING:
|
|||||||
SiglipTextConfig,
|
SiglipTextConfig,
|
||||||
SiglipVisionConfig,
|
SiglipVisionConfig,
|
||||||
)
|
)
|
||||||
|
from .models.smolvlm import SmolVLMConfig
|
||||||
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
|
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
|
||||||
from .models.speech_to_text import (
|
from .models.speech_to_text import (
|
||||||
Speech2TextConfig,
|
Speech2TextConfig,
|
||||||
@ -6459,6 +6472,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
|
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
|
||||||
from .models.seggpt import SegGptImageProcessor
|
from .models.seggpt import SegGptImageProcessor
|
||||||
from .models.siglip import SiglipImageProcessor
|
from .models.siglip import SiglipImageProcessor
|
||||||
|
from .models.smolvlm import SmolVLMImageProcessor
|
||||||
from .models.superglue import SuperGlueImageProcessor
|
from .models.superglue import SuperGlueImageProcessor
|
||||||
from .models.superpoint import SuperPointImageProcessor
|
from .models.superpoint import SuperPointImageProcessor
|
||||||
from .models.swin2sr import Swin2SRImageProcessor
|
from .models.swin2sr import Swin2SRImageProcessor
|
||||||
@ -8274,6 +8288,14 @@ if TYPE_CHECKING:
|
|||||||
SiglipTextModel,
|
SiglipTextModel,
|
||||||
SiglipVisionModel,
|
SiglipVisionModel,
|
||||||
)
|
)
|
||||||
|
from .models.smolvlm import (
|
||||||
|
SmolVLMForConditionalGeneration,
|
||||||
|
SmolVLMModel,
|
||||||
|
SmolVLMPreTrainedModel,
|
||||||
|
SmolVLMProcessor,
|
||||||
|
SmolVLMVisionConfig,
|
||||||
|
SmolVLMVisionTransformer,
|
||||||
|
)
|
||||||
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
|
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
|
||||||
from .models.speech_to_text import (
|
from .models.speech_to_text import (
|
||||||
Speech2TextForConditionalGeneration,
|
Speech2TextForConditionalGeneration,
|
||||||
|
@ -843,7 +843,7 @@ def load_video(
|
|||||||
file_obj = BytesIO(requests.get(video).content)
|
file_obj = BytesIO(requests.get(video).content)
|
||||||
elif os.path.isfile(video):
|
elif os.path.isfile(video):
|
||||||
file_obj = video
|
file_obj = video
|
||||||
elif is_valid_image(video) or (isinstance(video, (list, tuple) and is_valid_image(video[0]))):
|
elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
|
||||||
file_obj = None
|
file_obj = None
|
||||||
else:
|
else:
|
||||||
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
||||||
|
@ -245,6 +245,7 @@ from . import (
|
|||||||
sew,
|
sew,
|
||||||
sew_d,
|
sew_d,
|
||||||
siglip,
|
siglip,
|
||||||
|
smolvlm,
|
||||||
speech_encoder_decoder,
|
speech_encoder_decoder,
|
||||||
speech_to_text,
|
speech_to_text,
|
||||||
speecht5,
|
speecht5,
|
||||||
|
@ -272,6 +272,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("sew-d", "SEWDConfig"),
|
("sew-d", "SEWDConfig"),
|
||||||
("siglip", "SiglipConfig"),
|
("siglip", "SiglipConfig"),
|
||||||
("siglip_vision_model", "SiglipVisionConfig"),
|
("siglip_vision_model", "SiglipVisionConfig"),
|
||||||
|
("smolvlm", "SmolVLMConfig"),
|
||||||
|
("smolvlm_vision", "SmolVLMVisionConfig"),
|
||||||
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
|
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
|
||||||
("speech_to_text", "Speech2TextConfig"),
|
("speech_to_text", "Speech2TextConfig"),
|
||||||
("speech_to_text_2", "Speech2Text2Config"),
|
("speech_to_text_2", "Speech2Text2Config"),
|
||||||
@ -616,6 +618,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("sew-d", "SEW-D"),
|
("sew-d", "SEW-D"),
|
||||||
("siglip", "SigLIP"),
|
("siglip", "SigLIP"),
|
||||||
("siglip_vision_model", "SiglipVisionModel"),
|
("siglip_vision_model", "SiglipVisionModel"),
|
||||||
|
("smolvlm", "SmolVLM"),
|
||||||
|
("smolvlm_vision", "SmolVLMVisionTransformer"),
|
||||||
("speech-encoder-decoder", "Speech Encoder decoder"),
|
("speech-encoder-decoder", "Speech Encoder decoder"),
|
||||||
("speech_to_text", "Speech2Text"),
|
("speech_to_text", "Speech2Text"),
|
||||||
("speech_to_text_2", "Speech2Text2"),
|
("speech_to_text_2", "Speech2Text2"),
|
||||||
@ -741,6 +745,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
|||||||
("aria_text", "aria"),
|
("aria_text", "aria"),
|
||||||
("idefics3_vision", "idefics3"),
|
("idefics3_vision", "idefics3"),
|
||||||
("siglip_vision_model", "siglip"),
|
("siglip_vision_model", "siglip"),
|
||||||
|
("smolvlm_vision", "smolvlm"),
|
||||||
("chinese_clip_vision_model", "chinese_clip"),
|
("chinese_clip_vision_model", "chinese_clip"),
|
||||||
("rt_detr_resnet", "rt_detr"),
|
("rt_detr_resnet", "rt_detr"),
|
||||||
("granitevision", "llava_next"),
|
("granitevision", "llava_next"),
|
||||||
|
@ -251,6 +251,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("sew-d", "SEWDModel"),
|
("sew-d", "SEWDModel"),
|
||||||
("siglip", "SiglipModel"),
|
("siglip", "SiglipModel"),
|
||||||
("siglip_vision_model", "SiglipVisionModel"),
|
("siglip_vision_model", "SiglipVisionModel"),
|
||||||
|
("smolvlm", "SmolVLMModel"),
|
||||||
|
("smolvlm_vision", "SmolVLMVisionTransformer"),
|
||||||
("speech_to_text", "Speech2TextModel"),
|
("speech_to_text", "Speech2TextModel"),
|
||||||
("speecht5", "SpeechT5Model"),
|
("speecht5", "SpeechT5Model"),
|
||||||
("splinter", "SplinterModel"),
|
("splinter", "SplinterModel"),
|
||||||
@ -835,6 +837,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
|||||||
("pixtral", "LlavaForConditionalGeneration"),
|
("pixtral", "LlavaForConditionalGeneration"),
|
||||||
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
|
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
|
||||||
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
|
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
|
||||||
|
("smolvlm", "SmolVLMForConditionalGeneration"),
|
||||||
("udop", "UdopForConditionalGeneration"),
|
("udop", "UdopForConditionalGeneration"),
|
||||||
("vipllava", "VipLlavaForConditionalGeneration"),
|
("vipllava", "VipLlavaForConditionalGeneration"),
|
||||||
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
|
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
|
||||||
|
@ -22,12 +22,12 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ... import PreTrainedModel
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||||
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
|
29
src/transformers/models/smolvlm/__init__.py
Normal file
29
src/transformers/models/smolvlm/__init__.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_smolvlm import *
|
||||||
|
from .image_processing_smolvlm import *
|
||||||
|
from .modeling_smolvlm import *
|
||||||
|
from .processing_smolvlm import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
197
src/transformers/models/smolvlm/configuration_smolvlm.py
Normal file
197
src/transformers/models/smolvlm/configuration_smolvlm.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_smolvlm.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Written by Orr Zohar
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import logging
|
||||||
|
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a
|
||||||
|
SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
|
||||||
|
[google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM
|
||||||
|
[HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct).
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_size (`int`, *optional*, defaults to 1152):
|
||||||
|
Dimensionality of the encoder layers and the pooler layer.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_channels (`int`, *optional*, defaults to 3):
|
||||||
|
Number of channels in the input images.
|
||||||
|
image_size (`int`, *optional*, defaults to 224):
|
||||||
|
The size (resolution) of each image.
|
||||||
|
patch_size (`int`, *optional*, defaults to 32):
|
||||||
|
The size (resolution) of each patch.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||||
|
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the layer normalization layers.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
|
||||||
|
>>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig
|
||||||
|
|
||||||
|
>>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration
|
||||||
|
>>> configuration = SmolVLMVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration
|
||||||
|
>>> model = SmolVLMVisionTransformer(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "smolvlm_vision"
|
||||||
|
base_config_key = "vision_config"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=1152,
|
||||||
|
intermediate_size=3072,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=32,
|
||||||
|
hidden_act="gelu_pytorch_tanh",
|
||||||
|
layer_norm_eps=1e-6,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.image_size = image_size
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
|
||||||
|
SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM
|
||||||
|
[HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should cache the key/value pairs of the attention mechanism. Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
image_token_id (`int`, *optional*, defaults to 128257):
|
||||||
|
The id of the "image" token.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to tie the word embeddings with the token embeddings.
|
||||||
|
vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`):
|
||||||
|
Custom vision config or dict for the vision tower
|
||||||
|
text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
|
||||||
|
Custom text config or dict for the text model
|
||||||
|
scale_factor (`int`, *optional*, defaults to 2):
|
||||||
|
The scale factor for the image encoder.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 128002):
|
||||||
|
The id of the padding token.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import SmolVLMModel, SmolVLMConfig
|
||||||
|
>>> # Initializing configuration
|
||||||
|
>>> configuration = SmolVLMConfig()
|
||||||
|
>>> # Initializing a model from the configuration
|
||||||
|
>>> model = SmolVLMModel(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "smolvlm"
|
||||||
|
sub_configs = {"text_config": AutoConfig, "vision_config": SmolVLMVisionConfig}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_cache=True,
|
||||||
|
image_token_id=128257,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
vision_config=None,
|
||||||
|
text_config=None,
|
||||||
|
scale_factor=2,
|
||||||
|
pad_token_id=128_002,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.image_token_id = image_token_id
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
|
||||||
|
if vision_config is None:
|
||||||
|
self.vision_config = SmolVLMVisionConfig()
|
||||||
|
logger.info("vision_config is None, using default vision config")
|
||||||
|
elif isinstance(vision_config, dict):
|
||||||
|
self.vision_config = SmolVLMVisionConfig(**vision_config)
|
||||||
|
elif isinstance(vision_config, SmolVLMVisionConfig):
|
||||||
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
if isinstance(text_config, dict):
|
||||||
|
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
||||||
|
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||||
|
elif text_config is None:
|
||||||
|
logger.info("text_config is None, using default text config")
|
||||||
|
text_config = CONFIG_MAPPING["llama"](
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
super().__init__(**kwargs, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["SmolVLMVisionConfig", "SmolVLMConfig"]
|
851
src/transformers/models/smolvlm/image_processing_smolvlm.py
Normal file
851
src/transformers/models/smolvlm/image_processing_smolvlm.py
Normal file
@ -0,0 +1,851 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_smolvlm.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Written by Orr Zohar
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import math
|
||||||
|
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||||
|
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
|
||||||
|
from ...image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_size,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
is_scaled_image,
|
||||||
|
make_nested_list_of_images,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
validate_preprocess_arguments,
|
||||||
|
)
|
||||||
|
from ...utils import TensorType, is_vision_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_output_size_rescale_to_max_len(
|
||||||
|
height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
min_len (`int`, *optional*, defaults to 1):
|
||||||
|
Minimum size of the output image.
|
||||||
|
max_len (`int`, *optional*, defaults to the maximum size of the image):
|
||||||
|
Maximum size of the output image.
|
||||||
|
Returns:
|
||||||
|
The output size of the image after resizing.
|
||||||
|
"""
|
||||||
|
max_len = max(height, width) if max_len is None else max_len
|
||||||
|
aspect_ratio = width / height
|
||||||
|
|
||||||
|
if width >= height:
|
||||||
|
width = max_len
|
||||||
|
height = int(width / aspect_ratio)
|
||||||
|
if height % 2 != 0:
|
||||||
|
height += 1
|
||||||
|
elif height > width:
|
||||||
|
height = max_len
|
||||||
|
width = int(height * aspect_ratio)
|
||||||
|
if width % 2 != 0:
|
||||||
|
width += 1
|
||||||
|
|
||||||
|
# Avoid resizing to a size smaller than min_len
|
||||||
|
height = max(height, min_len)
|
||||||
|
width = max(width, min_len)
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_output_size_scale_below_upper_bound(
|
||||||
|
height: int, width: int, max_len: Optional[Dict[str, int]] = None
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
|
||||||
|
Defines the maximum dimensions of the image.
|
||||||
|
Returns:
|
||||||
|
The output size of the image after resizing.
|
||||||
|
"""
|
||||||
|
max_len = max(height, width) if max_len is None else max_len
|
||||||
|
|
||||||
|
aspect_ratio = width / height
|
||||||
|
if width >= height and width > max_len:
|
||||||
|
width = max_len
|
||||||
|
height = int(width / aspect_ratio)
|
||||||
|
elif height > width and height > max_len:
|
||||||
|
height = max_len
|
||||||
|
width = int(height * aspect_ratio)
|
||||||
|
|
||||||
|
# Avoid resizing to a size smaller than 1
|
||||||
|
height = max(height, 1)
|
||||||
|
width = max(width, 1)
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_output_image_size(
|
||||||
|
image,
|
||||||
|
resolution_max_side: int,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
resolution_max_side (`int`):
|
||||||
|
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
|
||||||
|
input aspect ratio.
|
||||||
|
input_data_format (`ChannelDimension` or `str`):
|
||||||
|
The channel dimension format of the input image.
|
||||||
|
Returns:
|
||||||
|
The output size of the image after resizing.
|
||||||
|
"""
|
||||||
|
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
|
||||||
|
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
|
||||||
|
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
|
||||||
|
# Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
|
||||||
|
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_height_width(
|
||||||
|
images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Get the maximum height and width across all images in a batch.
|
||||||
|
"""
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
|
||||||
|
|
||||||
|
max_height = max_width = float("-inf")
|
||||||
|
for images in images_list:
|
||||||
|
for image in images:
|
||||||
|
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
max_height = max(height, max_height)
|
||||||
|
max_width = max(width, max_width)
|
||||||
|
return (max_height, max_width)
|
||||||
|
|
||||||
|
|
||||||
|
def make_pixel_mask(
|
||||||
|
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to make the pixel mask for.
|
||||||
|
output_size (`Tuple[int, int]`):
|
||||||
|
Output size of the mask.
|
||||||
|
"""
|
||||||
|
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
mask = np.zeros(output_size, dtype=np.int64)
|
||||||
|
mask[:input_height, :input_width] = 1
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_rgb(
|
||||||
|
image: np.ndarray,
|
||||||
|
palette: Optional[PIL.ImagePalette.ImagePalette] = None,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> ImageInput:
|
||||||
|
"""
|
||||||
|
Converts an image to RGB format.
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
The image to convert.
|
||||||
|
palette (List[int], *optional*):
|
||||||
|
The palette to use if given.
|
||||||
|
data_format (ChannelDimension or str, *optional*):
|
||||||
|
The channel dimension format for the output image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (ChannelDimension or str, *optional*):
|
||||||
|
The channel dimension format of the input image.
|
||||||
|
"""
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
|
||||||
|
|
||||||
|
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
||||||
|
# The resized image from PIL will always have channels last, so find the input format first.
|
||||||
|
data_format = input_data_format if data_format is None else data_format
|
||||||
|
|
||||||
|
mode = "P" if palette is not None else None
|
||||||
|
image = to_pil_image(image, image_mode=mode, input_data_format=input_data_format)
|
||||||
|
if image.mode == "P" and palette is not None:
|
||||||
|
image.putpalette(palette)
|
||||||
|
|
||||||
|
image_rgba = image.convert("RGBA")
|
||||||
|
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
||||||
|
alpha_composite = Image.alpha_composite(background, image_rgba)
|
||||||
|
alpha_composite = alpha_composite.convert("RGB")
|
||||||
|
|
||||||
|
output_array = np.array(alpha_composite)
|
||||||
|
# The image is always in channels last format after converting from a PIL image
|
||||||
|
output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST)
|
||||||
|
return output_array
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME Amy: make a more general crop function that isn't just centre crop
|
||||||
|
def _crop(
|
||||||
|
image: np.ndarray,
|
||||||
|
w1: int,
|
||||||
|
h1: int,
|
||||||
|
w2: int,
|
||||||
|
h2: int,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
if data_format is None:
|
||||||
|
data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
|
||||||
|
|
||||||
|
if data_format == ChannelDimension.FIRST:
|
||||||
|
image = image[:, h1:h2, w1:w2]
|
||||||
|
elif data_format == ChannelDimension.LAST:
|
||||||
|
image = image[h1:h2, w1:w2, :]
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid channel dimension format.")
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a SmolVLM image processor.
|
||||||
|
Args:
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
|
||||||
|
Only has an effect if the input image is in the PIL format.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
|
||||||
|
shortest edge resized to keep the input aspect ratio.
|
||||||
|
size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`):
|
||||||
|
Controls the size of the output image. This is a dictionary containing the key "longest_edge".
|
||||||
|
The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized
|
||||||
|
to keep the input aspect ratio.
|
||||||
|
resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
do_image_splitting (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to split the image into sub-images concatenated with the original image. They are split into patches
|
||||||
|
such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
|
||||||
|
max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
|
||||||
|
Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `1/255`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
|
||||||
|
a standard deviation of `image_std`.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
||||||
|
overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to pad the images to the largest height and width in the batch and number of images per
|
||||||
|
sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_convert_rgb: bool = True,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
||||||
|
do_image_splitting: bool = True,
|
||||||
|
max_image_size: Dict[str, int] = None,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: float = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_pad: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size if size is not None else {"longest_edge": 4 * 364}
|
||||||
|
self.resample = resample
|
||||||
|
self.do_image_splitting = do_image_splitting
|
||||||
|
self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
|
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
|
self.do_pad = do_pad
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
|
||||||
|
resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Size of the output image.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the output image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
|
||||||
|
|
||||||
|
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
||||||
|
# The resized image from PIL will always have channels last, so find the input format first.
|
||||||
|
data_format = input_data_format if data_format is None else data_format
|
||||||
|
|
||||||
|
if "longest_edge" in size:
|
||||||
|
size = get_resize_output_image_size(
|
||||||
|
image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format
|
||||||
|
)
|
||||||
|
elif "height" in size and "width" in size:
|
||||||
|
size = (size["height"], size["width"])
|
||||||
|
else:
|
||||||
|
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
|
||||||
|
|
||||||
|
image_mode = None
|
||||||
|
if image.ndim == 2 or image.shape[-1] == 1:
|
||||||
|
image_mode = "P"
|
||||||
|
image = to_pil_image(image, image_mode=image_mode, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
resized_image = image.resize((size[1], size[0]), resample=resample)
|
||||||
|
resized_image = np.array(resized_image)
|
||||||
|
|
||||||
|
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
|
||||||
|
# so we need to add it back if necessary.
|
||||||
|
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
|
||||||
|
# The image is always in channels last format after converting from a PIL image
|
||||||
|
resized_image = to_channel_dimension_format(
|
||||||
|
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
|
||||||
|
)
|
||||||
|
return resized_image
|
||||||
|
|
||||||
|
def split_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
max_image_size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Split an image into squares of side max_image_size and the original image resized to max_image_size.
|
||||||
|
That means that a single image becomes a sequence of images.
|
||||||
|
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
|
||||||
|
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
|
||||||
|
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
|
||||||
|
sub-images of the same size each (image_size, image_size). Typically, 364x364.
|
||||||
|
3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Images to split.
|
||||||
|
max_image_size (`Dict[str, int]`):
|
||||||
|
Maximum size of the output image. If the image is larger than this size, it will be split into
|
||||||
|
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the output image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
max_height = max_width = max_image_size["longest_edge"]
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
if height > max_height or width > max_width:
|
||||||
|
# Calculate the number of splits
|
||||||
|
num_splits_h = math.ceil(height / max_height)
|
||||||
|
num_splits_w = math.ceil(width / max_width)
|
||||||
|
# Calculate the optimal width and height for the sub-images
|
||||||
|
optimal_height = math.ceil(height / num_splits_h)
|
||||||
|
optimal_width = math.ceil(width / num_splits_w)
|
||||||
|
|
||||||
|
# Iterate through each row and column
|
||||||
|
for r in range(num_splits_h):
|
||||||
|
for c in range(num_splits_w):
|
||||||
|
# Calculate the starting point of the crop
|
||||||
|
start_x = c * optimal_width
|
||||||
|
start_y = r * optimal_height
|
||||||
|
|
||||||
|
# Calculate the ending point of the crop
|
||||||
|
end_x = min(start_x + optimal_width, width)
|
||||||
|
end_y = min(start_y + optimal_height, height)
|
||||||
|
|
||||||
|
# Crop the image
|
||||||
|
cropped_image = _crop(
|
||||||
|
image,
|
||||||
|
start_x,
|
||||||
|
start_y,
|
||||||
|
end_x,
|
||||||
|
end_y,
|
||||||
|
data_format=data_format,
|
||||||
|
)
|
||||||
|
frames.append(cropped_image)
|
||||||
|
|
||||||
|
# For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
|
||||||
|
global_image_height, global_image_width = max_height, max_width
|
||||||
|
if height != global_image_height or width != global_image_width:
|
||||||
|
image = self.resize(
|
||||||
|
image,
|
||||||
|
{"height": global_image_height, "width": global_image_width},
|
||||||
|
resample=resample,
|
||||||
|
input_data_format=data_format,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_splits_h, num_splits_w = 0, 0
|
||||||
|
|
||||||
|
frames.append(image)
|
||||||
|
|
||||||
|
return frames, num_splits_h, num_splits_w
|
||||||
|
|
||||||
|
def resize_for_vision_encoder(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
vision_encoder_max_size: int,
|
||||||
|
resample: PILImageResampling = PILImageResampling.LANCZOS,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Images to resize.
|
||||||
|
vision_encoder_max_size (`int`):
|
||||||
|
Maximum size of the output image. If the image is larger than this size, it will be split into
|
||||||
|
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the output image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred
|
||||||
|
"""
|
||||||
|
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
|
||||||
|
aspect_ratio = width / height
|
||||||
|
if width >= height:
|
||||||
|
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
|
||||||
|
height = int(width / aspect_ratio)
|
||||||
|
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
|
||||||
|
elif height > width:
|
||||||
|
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
|
||||||
|
width = int(height * aspect_ratio)
|
||||||
|
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
|
||||||
|
new_size = {"height": height, "width": width}
|
||||||
|
return self.resize(
|
||||||
|
image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pad_image(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
output_size: Tuple[int, int],
|
||||||
|
constant_values: Union[float, Iterable[float]] = 0,
|
||||||
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Pad an image with zeros to the given size.
|
||||||
|
"""
|
||||||
|
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||||
|
output_height, output_width = output_size
|
||||||
|
|
||||||
|
pad_bottom = output_height - input_height
|
||||||
|
pad_right = output_width - input_width
|
||||||
|
padding = ((0, pad_bottom), (0, pad_right))
|
||||||
|
padded_image = pad(
|
||||||
|
image,
|
||||||
|
padding,
|
||||||
|
mode=PaddingMode.CONSTANT,
|
||||||
|
constant_values=constant_values,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
return padded_image
|
||||||
|
|
||||||
|
def pad(
|
||||||
|
self,
|
||||||
|
images: List[np.ndarray],
|
||||||
|
constant_values: Union[float, Iterable[float]] = 0,
|
||||||
|
return_pixel_mask: bool = True,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
For a list of images, for each images, pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width.
|
||||||
|
For each sample in the batch, pads the sample with empty images to the max_number of images per sample in the batch. Optionally returns a pixel mask.
|
||||||
|
Args:
|
||||||
|
images (`List[np.ndarray]`):
|
||||||
|
List of list of images to pad. Pads to the largest height and width in the batch.
|
||||||
|
constant_values (`float` or `Iterable[float]`, *optional*):
|
||||||
|
The value to use for the padding if `mode` is `"constant"`.
|
||||||
|
return_pixel_mask (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return a pixel mask.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
batch_size = len(images)
|
||||||
|
max_num_images = max(len(images_) for images_ in images)
|
||||||
|
input_data_format = (
|
||||||
|
infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4))
|
||||||
|
if input_data_format is None
|
||||||
|
else input_data_format
|
||||||
|
)
|
||||||
|
data_format = input_data_format if data_format is None else data_format
|
||||||
|
|
||||||
|
if input_data_format == ChannelDimension.FIRST:
|
||||||
|
n_channels = images[0][0].shape[0]
|
||||||
|
elif input_data_format == ChannelDimension.LAST:
|
||||||
|
n_channels = images[0][0].shape[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid channel dimension format.")
|
||||||
|
|
||||||
|
def empty_image(size, input_data_format):
|
||||||
|
if input_data_format == ChannelDimension.FIRST:
|
||||||
|
return np.zeros((n_channels, *size), dtype=np.uint8)
|
||||||
|
elif input_data_format == ChannelDimension.LAST:
|
||||||
|
return np.zeros((*size, n_channels), dtype=np.uint8)
|
||||||
|
|
||||||
|
padded_images_list = [
|
||||||
|
[empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
for sample_idx, image in enumerate(images[batch_idx]):
|
||||||
|
padded_images_list[batch_idx][sample_idx] = self._pad_image(
|
||||||
|
image,
|
||||||
|
pad_size,
|
||||||
|
constant_values=constant_values,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
padded_masks[batch_idx][sample_idx] = make_pixel_mask(
|
||||||
|
image, output_size=pad_size, input_data_format=input_data_format
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_masks = padded_masks if return_pixel_mask else None
|
||||||
|
return padded_images_list, padded_masks
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_convert_rgb: Optional[bool] = None,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_image_splitting: Optional[bool] = None,
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
max_image_size: Optional[Dict[str, int]] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_pad: Optional[bool] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
return_row_col_info: bool = False,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Preprocess a batch of images.
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
A list of images to preprocess.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Size of the image after resizing. With the longest edge resized to keep the input aspect ratio.
|
||||||
|
resample (`int`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
|
||||||
|
Whether to split the image into sub-images concatenated with the original image. They are split into patches
|
||||||
|
such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
|
||||||
|
max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`):
|
||||||
|
Maximum resolution of the images. If the image is larger than this size, the image is split into patches.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||||
|
`True`.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||||
|
Whether or not to pad the images to the largest height and width in the batch.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
return_row_col_info (`bool`, *optional*, default to `False`):
|
||||||
|
Whether to return the number of rows and columns of the split images. This is used for the
|
||||||
|
`SmolVLMProcessor` to generate prompt strings based on the number of rows and columns.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
"""
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
|
||||||
|
max_image_size = max_image_size if max_image_size is not None else self.max_image_size
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||||
|
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||||
|
|
||||||
|
images_list = make_nested_list_of_images(images)
|
||||||
|
|
||||||
|
if not valid_images(images_list[0]):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# save the palettes for conversion to RGB
|
||||||
|
palettes_list = [
|
||||||
|
[im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
# Extra channel dimension for grayscale images
|
||||||
|
if input_data_format in [ChannelDimension.LAST, None]:
|
||||||
|
images_list = [
|
||||||
|
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
|
||||||
|
]
|
||||||
|
elif input_data_format == ChannelDimension.FIRST:
|
||||||
|
images_list = [
|
||||||
|
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_rescale and is_scaled_image(images_list[0][0]):
|
||||||
|
logger.warning_once(
|
||||||
|
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||||
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||||
|
)
|
||||||
|
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
|
||||||
|
|
||||||
|
if do_resize:
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_image_splitting:
|
||||||
|
# We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
|
||||||
|
# for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
|
||||||
|
# for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.resize_for_vision_encoder(
|
||||||
|
image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
images_list_split_arrays = []
|
||||||
|
palettes_list_split_arrays = []
|
||||||
|
images_list_rows = []
|
||||||
|
images_list_cols = []
|
||||||
|
for images, palettes in zip(images_list, palettes_list):
|
||||||
|
split_image_arrays = []
|
||||||
|
split_palettes_arrays = []
|
||||||
|
image_rows = []
|
||||||
|
image_cols = []
|
||||||
|
for image, palette in zip(images, palettes):
|
||||||
|
split_image_array, rows, cols = self.split_image(
|
||||||
|
image,
|
||||||
|
max_image_size=max_image_size,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
split_image_arrays.extend(split_image_array)
|
||||||
|
split_palettes_arrays.extend([palette] * len(split_image_array))
|
||||||
|
image_rows.append(rows)
|
||||||
|
image_cols.append(cols)
|
||||||
|
images_list_split_arrays.append(split_image_arrays)
|
||||||
|
palettes_list_split_arrays.append(split_palettes_arrays)
|
||||||
|
images_list_rows.append(image_rows)
|
||||||
|
images_list_cols.append(image_cols)
|
||||||
|
images_list = images_list_split_arrays
|
||||||
|
palettes_list = palettes_list_split_arrays
|
||||||
|
else:
|
||||||
|
# We square the images to max_image_size
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.resize(
|
||||||
|
image=image,
|
||||||
|
size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]},
|
||||||
|
resample=resample,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
images_list_rows = [[0] * len(images) for images in images_list]
|
||||||
|
images_list_cols = [[0] * len(images) for images in images_list]
|
||||||
|
|
||||||
|
if do_convert_rgb:
|
||||||
|
images_list = [
|
||||||
|
[convert_to_rgb(img, palette) for img, palette in zip(images, palettes)]
|
||||||
|
for images, palettes in zip(images_list, palettes_list)
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
images_list = [
|
||||||
|
[self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_normalize:
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
pixel_attention_mask = None
|
||||||
|
if do_pad:
|
||||||
|
images_list, pixel_attention_mask = self.pad(
|
||||||
|
images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_format is not None:
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# Faster tensor conversion
|
||||||
|
data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list}
|
||||||
|
if pixel_attention_mask is not None:
|
||||||
|
data["pixel_attention_mask"] = (
|
||||||
|
np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
encoding = BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
# This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
|
||||||
|
if return_row_col_info:
|
||||||
|
encoding["rows"] = images_list_rows
|
||||||
|
encoding["cols"] = images_list_cols
|
||||||
|
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["SmolVLMImageProcessor"]
|
1265
src/transformers/models/smolvlm/modeling_smolvlm.py
Normal file
1265
src/transformers/models/smolvlm/modeling_smolvlm.py
Normal file
File diff suppressed because it is too large
Load Diff
387
src/transformers/models/smolvlm/modular_smolvlm.py
Normal file
387
src/transformers/models/smolvlm/modular_smolvlm.py
Normal file
@ -0,0 +1,387 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Written by Orr Zohar
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...cache_utils import DynamicCache
|
||||||
|
from ...utils import (
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
|
||||||
|
from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
|
||||||
|
from ..idefics3.modeling_idefics3 import (
|
||||||
|
Idefics3BaseModelOutputWithPast,
|
||||||
|
Idefics3ForConditionalGeneration,
|
||||||
|
Idefics3Model,
|
||||||
|
Idefics3PreTrainedModel,
|
||||||
|
Idefics3VisionTransformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionConfig(Idefics3VisionConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`SmolVLMVisionModel`]. It is used to instantiate a
|
||||||
|
SmolVLM vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
|
||||||
|
[google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) used in SmolVLM
|
||||||
|
[HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct).
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_size (`int`, *optional*, defaults to 1152):
|
||||||
|
Dimensionality of the encoder layers and the pooler layer.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
num_channels (`int`, *optional*, defaults to 3):
|
||||||
|
Number of channels in the input images.
|
||||||
|
image_size (`int`, *optional*, defaults to 224):
|
||||||
|
The size (resolution) of each image.
|
||||||
|
patch_size (`int`, *optional*, defaults to 32):
|
||||||
|
The size (resolution) of each patch.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||||
|
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the layer normalization layers.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
|
||||||
|
>>> from transformers.models.smolvlm.configuration_smolvlm import SmolVLMVisionConfig
|
||||||
|
|
||||||
|
>>> # Initializing a SmolVLMVisionConfig with google/siglip-so400m-patch14-384 style configuration
|
||||||
|
>>> configuration = SmolVLMVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a SmolVLMVisionTransformer (with random weights) from the google/siglip-so400m-patch14-384 style configuration
|
||||||
|
>>> model = SmolVLMVisionTransformer(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "smolvlm_vision"
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMPreTrainedModel(Idefics3PreTrainedModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionTransformer(Idefics3VisionTransformer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMConfig(Idefics3Config):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
|
||||||
|
SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM
|
||||||
|
[HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should cache the key/value pairs of the attention mechanism. Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
image_token_id (`int`, *optional*, defaults to 128257):
|
||||||
|
The id of the "image" token.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to tie the word embeddings with the token embeddings.
|
||||||
|
vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`):
|
||||||
|
Custom vision config or dict for the vision tower
|
||||||
|
text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
|
||||||
|
Custom text config or dict for the text model
|
||||||
|
scale_factor (`int`, *optional*, defaults to 2):
|
||||||
|
The scale factor for the image encoder.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 128002):
|
||||||
|
The id of the padding token.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import SmolVLMModel, SmolVLMConfig
|
||||||
|
>>> # Initializing configuration
|
||||||
|
>>> configuration = SmolVLMConfig()
|
||||||
|
>>> # Initializing a model from the configuration
|
||||||
|
>>> model = SmolVLMModel(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "smolvlm"
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMImageProcessor(Idefics3ImageProcessor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMBaseModelOutputWithPast(Idefics3BaseModelOutputWithPast):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMModel(Idefics3Model):
|
||||||
|
"""
|
||||||
|
A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
|
||||||
|
in forward. Instead, we override inputs_merger here with custom logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def inputs_merger(
|
||||||
|
self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
|
||||||
|
):
|
||||||
|
_, patch_size, _ = image_hidden_states.shape
|
||||||
|
|
||||||
|
image_mask = input_ids == self.image_token_id
|
||||||
|
num_image_tokens = image_mask.sum(dim=1)
|
||||||
|
if not torch.all(num_image_tokens % patch_size == 0):
|
||||||
|
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
|
||||||
|
|
||||||
|
blocks_per_sample = num_image_tokens // patch_size
|
||||||
|
|
||||||
|
offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
|
||||||
|
block_offset = offsets[:-1]
|
||||||
|
row_cum = image_mask.cumsum(dim=-1)
|
||||||
|
chunk_idx = (row_cum - 1) // patch_size
|
||||||
|
local_idx = (row_cum - 1) % patch_size
|
||||||
|
block_idx = block_offset.unsqueeze(1) + chunk_idx
|
||||||
|
|
||||||
|
image_embeds = torch.zeros_like(inputs_embeds)
|
||||||
|
image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
|
||||||
|
|
||||||
|
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
|
||||||
|
return merged_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
image_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.training and self.text_model.gradient_checkpointing and use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
past_seen_tokens = 0
|
||||||
|
if use_cache:
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
|
||||||
|
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
|
||||||
|
|
||||||
|
# START VISUAL INPUTS INTEGRATION
|
||||||
|
if pixel_values is not None and image_hidden_states is not None:
|
||||||
|
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
|
||||||
|
elif pixel_values is not None:
|
||||||
|
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values
|
||||||
|
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||||
|
|
||||||
|
# Remove padding images - padding images are full 0.
|
||||||
|
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||||
|
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||||
|
|
||||||
|
if not any(real_images_inds):
|
||||||
|
# no images, leave one empty image.
|
||||||
|
real_images_inds[0] = True
|
||||||
|
|
||||||
|
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||||
|
|
||||||
|
# Handle the vision attention mask
|
||||||
|
if pixel_attention_mask is None:
|
||||||
|
pixel_attention_mask = torch.ones(
|
||||||
|
size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Remove padding images from the mask
|
||||||
|
pixel_attention_mask = pixel_attention_mask.view(
|
||||||
|
batch_size * num_images, *pixel_attention_mask.shape[2:]
|
||||||
|
)
|
||||||
|
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
|
||||||
|
|
||||||
|
patch_size = self.config.vision_config.patch_size
|
||||||
|
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||||
|
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||||
|
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||||
|
|
||||||
|
# Get sequence from the vision encoder
|
||||||
|
image_hidden_states = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
patch_attention_mask=patch_attention_mask,
|
||||||
|
).last_hidden_state
|
||||||
|
|
||||||
|
# Modality projection & resampling
|
||||||
|
image_hidden_states = self.connector(image_hidden_states)
|
||||||
|
|
||||||
|
elif image_hidden_states is not None:
|
||||||
|
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||||
|
|
||||||
|
if inputs_embeds is not None and image_hidden_states is not None:
|
||||||
|
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||||
|
# that simply don't exist
|
||||||
|
inputs_embeds = self.inputs_merger(
|
||||||
|
input_ids=input_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
image_hidden_states=image_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = self.text_model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
|
||||||
|
|
||||||
|
return SmolVLMBaseModelOutputWithPast(
|
||||||
|
last_hidden_state=outputs.last_hidden_state,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
|
||||||
|
"""
|
||||||
|
A subclass of Idefics3ForConditionalGeneration that uses SmolVLMModel
|
||||||
|
instead of the default Idefics3Model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = SmolVLMModel(config)
|
||||||
|
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def forward(self, **super_kwargs):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`).
|
||||||
|
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
|
||||||
|
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import requests
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> from io import BytesIO
|
||||||
|
|
||||||
|
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||||
|
>>> from transformers.image_utils import load_image
|
||||||
|
|
||||||
|
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
|
||||||
|
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
|
||||||
|
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
|
||||||
|
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
|
||||||
|
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
|
||||||
|
>>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
|
||||||
|
|
||||||
|
>>> # Create inputs
|
||||||
|
>>> messages = [
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "video", "path": path/to/video},
|
||||||
|
... {"type": "text", "text": "What is happening in this video?"},
|
||||||
|
... ]
|
||||||
|
... }
|
||||||
|
... ]
|
||||||
|
|
||||||
|
>>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True)
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generated_ids = model.generate(**inputs, max_new_tokens=256)
|
||||||
|
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
>>> print(generated_texts)
|
||||||
|
```"""
|
||||||
|
super().forward(**super_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SmolVLMVisionConfig",
|
||||||
|
"SmolVLMConfig",
|
||||||
|
"SmolVLMImageProcessor",
|
||||||
|
"SmolVLMForConditionalGeneration",
|
||||||
|
"SmolVLMPreTrainedModel",
|
||||||
|
"SmolVLMModel",
|
||||||
|
"SmolVLMVisionTransformer",
|
||||||
|
]
|
454
src/transformers/models/smolvlm/processing_smolvlm.py
Normal file
454
src/transformers/models/smolvlm/processing_smolvlm.py
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Processor class for SmolVLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from ...feature_extraction_utils import BatchFeature
|
||||||
|
from ...image_utils import (
|
||||||
|
ImageInput,
|
||||||
|
VideoInput,
|
||||||
|
make_batched_videos,
|
||||||
|
make_nested_list_of_images,
|
||||||
|
)
|
||||||
|
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
|
from ...tokenization_utils_base import BatchEncoding, TextInput
|
||||||
|
from ...utils import is_num2words_available, logging
|
||||||
|
from .video_processing_smolvlm import (
|
||||||
|
DEFAULT_MEDIA_OUTTRO,
|
||||||
|
DEFAULT_VIDEO_INTRO,
|
||||||
|
FRAME_TIMESTAMP_MESSAGE,
|
||||||
|
smolvlm_sample_indices_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...tokenization_utils_base import PreTokenizedInput
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_num2words_available():
|
||||||
|
from num2words import num2words
|
||||||
|
else:
|
||||||
|
num2words = None
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_split_image(
|
||||||
|
image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
|
||||||
|
):
|
||||||
|
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||||
|
text_split_images = ""
|
||||||
|
for n_h in range(image_rows):
|
||||||
|
for n_w in range(image_cols):
|
||||||
|
text_split_images += (
|
||||||
|
f"{fake_token_around_image}" + f"<row_{n_h + 1}_col_{n_w + 1}>" + f"{image_token}" * image_seq_len
|
||||||
|
)
|
||||||
|
text_split_images += "\n"
|
||||||
|
|
||||||
|
text_split_images += (
|
||||||
|
f"\n{fake_token_around_image}"
|
||||||
|
+ f"{global_image_token}"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
+ f"{fake_token_around_image}"
|
||||||
|
)
|
||||||
|
return text_split_images
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_image_token):
|
||||||
|
"""Prompt with expanded image tokens for a single image."""
|
||||||
|
return (
|
||||||
|
f"{fake_token_around_image}"
|
||||||
|
+ f"{global_image_token}"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
+ f"{fake_token_around_image}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_prompt_string(
|
||||||
|
image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_image_token
|
||||||
|
):
|
||||||
|
if image_rows == 0 and image_cols == 0:
|
||||||
|
return _prompt_single_image(
|
||||||
|
image_seq_len,
|
||||||
|
fake_token_around_image=fake_token_around_image,
|
||||||
|
image_token=image_token,
|
||||||
|
global_image_token=global_image_token,
|
||||||
|
)
|
||||||
|
return _prompt_split_image(
|
||||||
|
image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMImagesKwargs(ImagesKwargs, total=False):
|
||||||
|
return_row_col_info: Optional[bool]
|
||||||
|
max_image_size: Optional[Dict[str, int]]
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
images_kwargs: SmolVLMImagesKwargs
|
||||||
|
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"add_special_tokens": True,
|
||||||
|
"padding": False,
|
||||||
|
"is_split_into_words": False,
|
||||||
|
},
|
||||||
|
"images_kwargs": {
|
||||||
|
"return_row_col_info": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMProcessor(ProcessorMixin):
|
||||||
|
r"""
|
||||||
|
Constructs a SmolVLM processor which wraps a LLama tokenizer and SmolVLM image processor into a single processor.
|
||||||
|
|
||||||
|
[`SmolVLMProcessor`] offers all the functionalities of [`SmolVLMImageProcessor`] and [`SmolVLMTokenizerFast`]. See
|
||||||
|
the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_processor (`SmolVLMImageProcessor`):
|
||||||
|
An instance of [`SmolVLMImageProcessor`]. The image processor is a required input.
|
||||||
|
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
||||||
|
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
|
||||||
|
image_seq_len (`int`, *optional*, defaults to 169):
|
||||||
|
The length of the image sequence i.e. the number of <image> tokens per image in the input.
|
||||||
|
This parameter is used to build the string from the input prompt and image tokens and should match the
|
||||||
|
value the model used. It is computed as: image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
|
||||||
|
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||||
|
in a chat into a tokenizable string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attributes = ["image_processor", "tokenizer"]
|
||||||
|
valid_kwargs = ["image_seq_len", "chat_template"]
|
||||||
|
image_processor_class = "SmolVLMImageProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: str = None, **kwargs):
|
||||||
|
self.fake_image_token = getattr(tokenizer, "fake_image_token", "<fake_token_around_image>")
|
||||||
|
self.image_token = getattr(tokenizer, "image_token", "<image>")
|
||||||
|
self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
|
||||||
|
self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
|
||||||
|
self.image_seq_len = image_seq_len
|
||||||
|
|
||||||
|
self.video_size = image_processor.video_sampling["video_size"]
|
||||||
|
self.image_size = image_processor.size
|
||||||
|
|
||||||
|
self.do_image_splitting = image_processor.do_image_splitting
|
||||||
|
self.do_video_splitting = image_processor.video_sampling.get("do_image_splitting", False)
|
||||||
|
|
||||||
|
self.default_max_frames = image_processor.video_sampling["max_frames"]
|
||||||
|
self.default_fps = image_processor.video_sampling["fps"]
|
||||||
|
# Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits, optionally surrounded by newline characters
|
||||||
|
# self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
|
||||||
|
|
||||||
|
if not num2words:
|
||||||
|
raise ImportError(
|
||||||
|
"Package `num2words` is required to run SmolVLM processor. Install it with `pip install num2words`."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||||
|
|
||||||
|
def process_vision(self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None):
|
||||||
|
if text is not None:
|
||||||
|
n_images_in_text = [sample.count(self.image_token) for sample in text]
|
||||||
|
|
||||||
|
n_images_in_images = [len(sublist) for sublist in images]
|
||||||
|
image_inputs = self.image_processor(
|
||||||
|
images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs["images_kwargs"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
return None, image_inputs
|
||||||
|
|
||||||
|
if n_images_in_images != n_images_in_text:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
|
||||||
|
)
|
||||||
|
image_rows = image_inputs.pop("rows", [[0] * len(text)])
|
||||||
|
image_cols = image_inputs.pop("cols", [[0] * len(text)])
|
||||||
|
|
||||||
|
prompt_strings = []
|
||||||
|
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
||||||
|
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
|
||||||
|
image_prompt_strings = []
|
||||||
|
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
||||||
|
image_prompt_string = get_image_prompt_string(
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
self.image_seq_len,
|
||||||
|
image_token=self.image_token,
|
||||||
|
fake_token_around_image=self.fake_image_token,
|
||||||
|
global_image_token=self.global_image_token,
|
||||||
|
)
|
||||||
|
image_prompt_strings.append(image_prompt_string)
|
||||||
|
|
||||||
|
split_sample = sample.split(self.image_token)
|
||||||
|
if len(split_sample) == 0:
|
||||||
|
raise ValueError("The image token should be present in the text.")
|
||||||
|
|
||||||
|
# Place in the image prompt strings where the image tokens are
|
||||||
|
sample = split_sample[0]
|
||||||
|
for i, image_prompt_string in enumerate(image_prompt_strings):
|
||||||
|
sample += image_prompt_string + split_sample[i + 1]
|
||||||
|
prompt_strings.append(sample)
|
||||||
|
|
||||||
|
return prompt_strings, image_inputs
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
|
||||||
|
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
|
||||||
|
audio=None,
|
||||||
|
videos: VideoInput = None,
|
||||||
|
**kwargs: Unpack[SmolVLMProcessorKwargs],
|
||||||
|
) -> BatchEncoding:
|
||||||
|
"""
|
||||||
|
Processes the input prompts and returns a BatchEncoding.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import SmolVLMProcessor
|
||||||
|
>>> from transformers.image_utils import load_image
|
||||||
|
|
||||||
|
>>> processor = SmolVLMProcessor.from_pretrained("HuggingFaceM4/SmolVLM2-256M-Video-Instruct")
|
||||||
|
>>> processor.image_processor.do_image_splitting = False # Force as False to simplify the example
|
||||||
|
|
||||||
|
>>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
>>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
|
||||||
|
|
||||||
|
>>> image1, image2 = load_image(url1), load_image(url2)
|
||||||
|
>>> images = [[image1], [image2]]
|
||||||
|
|
||||||
|
>>> text = [
|
||||||
|
... "<image>In this image, we see",
|
||||||
|
... "bla bla bla<image>",
|
||||||
|
... ]
|
||||||
|
>>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
|
||||||
|
>>> input_ids = outputs.input_ids
|
||||||
|
>>> input_tokens = processor.tokenizer.batch_decode(input_ids)
|
||||||
|
>>> print(input_tokens)
|
||||||
|
['<|begin_of_text|><fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image> In this image, we see', '<|reserved_special_token_0|><|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>bla bla bla<fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image>']
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. If is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
|
||||||
|
text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
|
||||||
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
|
Wherever an image token, `<image>` is encountered it is expanded to
|
||||||
|
`<fake_token_around_image>` + `<row_x_col_y>` + `<image>` * `image_seq_len` * <fake_token_around_image>`.
|
||||||
|
return_tensors (`Union[str, TensorType]`, *optional*):
|
||||||
|
If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
|
||||||
|
information.
|
||||||
|
"""
|
||||||
|
if text is None and images is None and videos is None:
|
||||||
|
raise ValueError("You must provide one of `text`, `images` or `videos'.")
|
||||||
|
|
||||||
|
if text is None and ((images is None) ^ (videos is not None)):
|
||||||
|
raise ValueError("You must specify exactly one of `images` or `videos`")
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
SmolVLMProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
n_images_in_text = sum([sample.count(self.image_token) for sample in text])
|
||||||
|
if n_images_in_text > 0 and (images is None and videos is None):
|
||||||
|
raise ValueError(f"We detected {n_images_in_text} tokens in the text but no images/videos were passed")
|
||||||
|
|
||||||
|
inputs = BatchFeature()
|
||||||
|
# Images and videos are mutually exclusive, so process one which is present
|
||||||
|
if images is not None:
|
||||||
|
images = make_nested_list_of_images(images)
|
||||||
|
text, vision_inputs = self.process_vision(
|
||||||
|
text,
|
||||||
|
images,
|
||||||
|
output_kwargs,
|
||||||
|
do_image_splitting=self.do_image_splitting,
|
||||||
|
image_processor_size=self.image_size,
|
||||||
|
)
|
||||||
|
inputs.update(vision_inputs)
|
||||||
|
elif videos is not None:
|
||||||
|
videos = make_batched_videos(videos)
|
||||||
|
text, vision_inputs = self.process_vision(
|
||||||
|
text,
|
||||||
|
videos,
|
||||||
|
output_kwargs,
|
||||||
|
do_image_splitting=self.do_image_splitting,
|
||||||
|
image_processor_size=self.video_size,
|
||||||
|
)
|
||||||
|
inputs.update(vision_inputs)
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||||
|
inputs.update(text_inputs)
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _process_messages_for_chat_template(
|
||||||
|
self,
|
||||||
|
conversations: List[List[Dict[str, str]]],
|
||||||
|
batch_images: List[ImageInput],
|
||||||
|
batch_videos: List[VideoInput],
|
||||||
|
batch_video_metadata: List[List[Dict[str, any]]],
|
||||||
|
**chat_template_kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Used within `apply_chat_template` when a model has special way to process conversation history. For example,
|
||||||
|
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
|
||||||
|
were sampled. This information cannot be accessed before the video is loaded.
|
||||||
|
For most models it is a no-op, must be overriden by model processors which require special processing.
|
||||||
|
Args:
|
||||||
|
conversation (`List[Dict, str, str]`):
|
||||||
|
The conversation to process. Always comes in batched format.
|
||||||
|
batch_images (`List[List[ImageInput]]`):
|
||||||
|
Batch of images that were loaded from url/path defined in the conversation. The images
|
||||||
|
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
|
||||||
|
per batch.
|
||||||
|
batch_videos (`List[List[ImageInput]]`):
|
||||||
|
Batch of videos that were loaded from url/path defined in the conversation. The videos
|
||||||
|
are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
|
||||||
|
per batch.
|
||||||
|
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
|
||||||
|
Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video.
|
||||||
|
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
|
||||||
|
per batch.
|
||||||
|
"""
|
||||||
|
# We don't want to modify in-place the messages passed by user
|
||||||
|
# The user might want to add new turn on conv and continue generation
|
||||||
|
conversations = copy.deepcopy(conversations)
|
||||||
|
batch_num_frames, batch_timestamps = [], []
|
||||||
|
for metadata_list, video_list in zip(batch_video_metadata, batch_videos):
|
||||||
|
for metadata, video in zip(metadata_list, video_list):
|
||||||
|
duration_sec = getattr(metadata, "duration")
|
||||||
|
frames_idx = getattr(metadata, "frames_indices")
|
||||||
|
fps = getattr(metadata, "fps")
|
||||||
|
|
||||||
|
timestamps = []
|
||||||
|
for idx, frame_np in zip(frames_idx, video):
|
||||||
|
sec = idx / fps
|
||||||
|
mm = int(sec // 60)
|
||||||
|
ss = int(sec % 60)
|
||||||
|
timestamps.append(f"{mm:02d}:{ss:02d}")
|
||||||
|
batch_timestamps.append(timestamps)
|
||||||
|
batch_num_frames.append(len(video))
|
||||||
|
|
||||||
|
for conversation in conversations:
|
||||||
|
# For each message, scan content for {"type": "video"}
|
||||||
|
for msg in conversation:
|
||||||
|
if "content" not in msg:
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_content = []
|
||||||
|
for block in msg["content"]:
|
||||||
|
if block.get("type") == "video":
|
||||||
|
curr_timestamps = batch_timestamps.pop(0)
|
||||||
|
curr_num_frames = batch_num_frames.pop(0)
|
||||||
|
|
||||||
|
# Build the video intro texts
|
||||||
|
td = timedelta(seconds=int(duration_sec))
|
||||||
|
new_content.append(
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": DEFAULT_VIDEO_INTRO.format(
|
||||||
|
frame_count=num2words(curr_num_frames), video_duration=str(td)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Insert per-frame lines: "Frame from {timestamp}:", then an "image" block
|
||||||
|
for i, ts in enumerate(curr_timestamps):
|
||||||
|
new_content.append({"type": "text", "text": FRAME_TIMESTAMP_MESSAGE.format(timestamp=ts)})
|
||||||
|
new_content.append({"type": "image"})
|
||||||
|
|
||||||
|
# 3) Optionally add an outro (e.g. "Now answer the question:")
|
||||||
|
new_content.append({"type": "text", "text": DEFAULT_MEDIA_OUTTRO})
|
||||||
|
# Do NOT add the original block => we skip it (since we've replaced it)
|
||||||
|
else:
|
||||||
|
# keep original block
|
||||||
|
new_content.append(block)
|
||||||
|
|
||||||
|
# update the content
|
||||||
|
msg["content"] = new_content
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
refer to the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
return batched_decode_output
|
||||||
|
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||||
|
the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
decode_output = self.tokenizer.decode(*args, **kwargs)
|
||||||
|
return decode_output
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
|
||||||
|
|
||||||
|
# Add model-specific video sampling method when applying the template
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
conversation,
|
||||||
|
max_frames=None,
|
||||||
|
target_fps=None,
|
||||||
|
skip_secs=1,
|
||||||
|
video_load_backend="pyav",
|
||||||
|
sample_indices_fn=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
max_frames = self.default_max_frames if max_frames is None else max_frames
|
||||||
|
target_fps = self.default_fps if target_fps is None else target_fps
|
||||||
|
|
||||||
|
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||||
|
return smolvlm_sample_indices_fn(
|
||||||
|
metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# word of caution- we are blindly overriding a callable kwarg here.
|
||||||
|
# typed kwargs would be a way to avoid that @molbap
|
||||||
|
if not sample_indices_fn:
|
||||||
|
sample_indices_fn = sample_indices_fn_func
|
||||||
|
return super().apply_chat_template(
|
||||||
|
conversation, video_load_backend=video_load_backend, sample_indices_fn=sample_indices_fn, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["SmolVLMProcessor"]
|
90
src/transformers/models/smolvlm/video_processing_smolvlm.py
Normal file
90
src/transformers/models/smolvlm/video_processing_smolvlm.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Make sure these are imported from your library
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_MESSAGE = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
||||||
|
DEFAULT_VIDEO_INTRO = (
|
||||||
|
"You are provided the following series of {frame_count} frames from a {video_duration} [H:MM:SS] video.\n"
|
||||||
|
)
|
||||||
|
DEFAULT_MEDIA_OUTTRO = "\n\n"
|
||||||
|
FRAME_TIMESTAMP_MESSAGE = "\nFrame from {timestamp}:"
|
||||||
|
|
||||||
|
|
||||||
|
def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
|
||||||
|
"""
|
||||||
|
Example sampling function which:
|
||||||
|
- Uses `max_frames` (if provided) or calculates it from `fps` and metadata.
|
||||||
|
- Applies a basic center-skip if fewer frames than available, otherwise
|
||||||
|
optionally skips `skip_secs` from both the start and end.
|
||||||
|
- Uniformly samples the desired number of frames between the start and end indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_frames (`int`):
|
||||||
|
Maximum number of frames to sample.
|
||||||
|
target_fps (`int`):
|
||||||
|
Target frames to sample per second.
|
||||||
|
metadata (`dict`):
|
||||||
|
Contains video metadata such as "n_frames" and "video_fps".
|
||||||
|
skip_secs (`float`, *optional*, defaults to 1.0):
|
||||||
|
Number of seconds to skip from the start and end if the video is long enough.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray:
|
||||||
|
An array of unique frame indices to sample.
|
||||||
|
"""
|
||||||
|
|
||||||
|
total_num_frames = getattr(metadata, "total_num_frames", 0)
|
||||||
|
if total_num_frames <= 0:
|
||||||
|
raise ValueError(f"Invalid total_num_frames={total_num_frames} in metadata.")
|
||||||
|
|
||||||
|
native_fps = getattr(metadata, "fps", 30.0)
|
||||||
|
duration_seconds = getattr(metadata, "duration", 0)
|
||||||
|
|
||||||
|
if duration_seconds <= 0:
|
||||||
|
raise ValueError(f"Invalid duration_seconds={duration_seconds} in metadata.")
|
||||||
|
|
||||||
|
# Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0
|
||||||
|
estimated_frames = int(round(target_fps * duration_seconds))
|
||||||
|
|
||||||
|
# Step 2) desired_frames
|
||||||
|
desired_frames = min(estimated_frames, max_frames)
|
||||||
|
if desired_frames < 1:
|
||||||
|
desired_frames = 1
|
||||||
|
|
||||||
|
# Step 3) center skip logic
|
||||||
|
start_idx = 0
|
||||||
|
end_idx = total_num_frames - 1
|
||||||
|
|
||||||
|
if skip_secs > 0 and (duration_seconds - 2 * skip_secs) > (max_frames * target_fps):
|
||||||
|
start_idx = int(skip_secs * native_fps)
|
||||||
|
end_idx = int(total_num_frames - skip_secs * native_fps)
|
||||||
|
|
||||||
|
start_idx = max(0, start_idx)
|
||||||
|
end_idx = min(end_idx, total_num_frames - 1)
|
||||||
|
if start_idx >= end_idx:
|
||||||
|
start_idx, end_idx = 0, total_num_frames - 1
|
||||||
|
|
||||||
|
indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
|
||||||
|
indices = np.unique(indices)
|
||||||
|
|
||||||
|
return indices
|
@ -165,6 +165,7 @@ from .import_utils import (
|
|||||||
is_natten_available,
|
is_natten_available,
|
||||||
is_ninja_available,
|
is_ninja_available,
|
||||||
is_nltk_available,
|
is_nltk_available,
|
||||||
|
is_num2words_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_openai_available,
|
is_openai_available,
|
||||||
is_optimum_available,
|
is_optimum_available,
|
||||||
|
@ -8849,6 +8849,48 @@ class SiglipVisionModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMForConditionalGeneration(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionConfig(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionTransformer(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class SpeechEncoderDecoderModel(metaclass=DummyObject):
|
class SpeechEncoderDecoderModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
@ -639,6 +639,13 @@ class SiglipImageProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["vision"])
|
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMImageProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["vision"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
class SuperGlueImageProcessor(metaclass=DummyObject):
|
class SuperGlueImageProcessor(metaclass=DummyObject):
|
||||||
_backends = ["vision"]
|
_backends = ["vision"]
|
||||||
|
|
||||||
|
@ -196,6 +196,7 @@ _torchao_available = _is_package_available("torchao")
|
|||||||
_torchdistx_available = _is_package_available("torchdistx")
|
_torchdistx_available = _is_package_available("torchdistx")
|
||||||
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
|
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
|
||||||
_mlx_available = _is_package_available("mlx")
|
_mlx_available = _is_package_available("mlx")
|
||||||
|
_num2words_available = _is_package_available("num2words")
|
||||||
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
|
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
|
||||||
_tiktoken_available = _is_package_available("tiktoken")
|
_tiktoken_available = _is_package_available("tiktoken")
|
||||||
_blobfile_available = _is_package_available("blobfile")
|
_blobfile_available = _is_package_available("blobfile")
|
||||||
@ -1280,6 +1281,10 @@ def is_mlx_available():
|
|||||||
return _mlx_available
|
return _mlx_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_num2words_available():
|
||||||
|
return _num2words_available
|
||||||
|
|
||||||
|
|
||||||
def is_tiktoken_available():
|
def is_tiktoken_available():
|
||||||
return _tiktoken_available and _blobfile_available
|
return _tiktoken_available and _blobfile_available
|
||||||
|
|
||||||
|
@ -193,6 +193,10 @@ class Idefics3ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_flash_attn_2_inference_padding_right(self):
|
def test_flash_attn_2_inference_padding_right(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in idefics3 models")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# We need to override as we need to prepare such that the image token is the last token
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -377,6 +381,10 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
|||||||
def test_eager_matches_sdpa_generate(self):
|
def test_eager_matches_sdpa_generate(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in Idefics3 models end-to-end")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# We need to override as we need to prepare such that the image token is the last token
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
0
tests/models/smolvlm/__init__.py
Normal file
0
tests/models/smolvlm/__init__.py
Normal file
284
tests/models/smolvlm/test_image_processing_smolvlm.py
Normal file
284
tests/models/smolvlm/test_image_processing_smolvlm.py
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.image_utils import PILImageResampling
|
||||||
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_image_processing_common import ImageProcessingTestMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from transformers import SmolVLMImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMImageProcessingTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=7,
|
||||||
|
num_channels=3,
|
||||||
|
num_images=1,
|
||||||
|
image_size=18,
|
||||||
|
min_resolution=30,
|
||||||
|
max_resolution=40,
|
||||||
|
do_resize=True,
|
||||||
|
size=None,
|
||||||
|
max_image_size=None,
|
||||||
|
do_rescale=True,
|
||||||
|
rescale_factor=1 / 255,
|
||||||
|
do_normalize=True,
|
||||||
|
image_mean=[0.5, 0.5, 0.5],
|
||||||
|
image_std=[0.5, 0.5, 0.5],
|
||||||
|
do_convert_rgb=True,
|
||||||
|
do_pad=True,
|
||||||
|
do_image_splitting=True,
|
||||||
|
resample=PILImageResampling.LANCZOS,
|
||||||
|
):
|
||||||
|
self.size = size if size is not None else {"longest_edge": max_resolution}
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.num_images = num_images
|
||||||
|
self.image_size = image_size
|
||||||
|
self.min_resolution = min_resolution
|
||||||
|
self.max_resolution = max_resolution
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.resample = resample
|
||||||
|
self.do_image_splitting = do_image_splitting
|
||||||
|
self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 20}
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean
|
||||||
|
self.image_std = image_std
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
self.do_pad = do_pad
|
||||||
|
|
||||||
|
def prepare_image_processor_dict(self):
|
||||||
|
return {
|
||||||
|
"do_convert_rgb": self.do_convert_rgb,
|
||||||
|
"do_resize": self.do_resize,
|
||||||
|
"size": self.size,
|
||||||
|
"max_image_size": self.max_image_size,
|
||||||
|
"do_rescale": self.do_rescale,
|
||||||
|
"rescale_factor": self.rescale_factor,
|
||||||
|
"do_normalize": self.do_normalize,
|
||||||
|
"image_mean": self.image_mean,
|
||||||
|
"image_std": self.image_std,
|
||||||
|
"do_pad": self.do_pad,
|
||||||
|
"do_image_splitting": self.do_image_splitting,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_expected_values(self, image_inputs, batched=False):
|
||||||
|
"""
|
||||||
|
This function computes the expected height and width when providing images to SmolVLMImageProcessor,
|
||||||
|
assuming do_resize is set to True. The expected size in that case the max image size.
|
||||||
|
"""
|
||||||
|
return self.max_image_size["longest_edge"], self.max_image_size["longest_edge"]
|
||||||
|
|
||||||
|
def expected_output_image_shape(self, images):
|
||||||
|
height, width = self.get_expected_values(images, batched=True)
|
||||||
|
effective_nb_images = (
|
||||||
|
self.num_images * 5 if self.do_image_splitting else 1
|
||||||
|
) # 5 is a squared image divided into 4 + global image resized
|
||||||
|
return effective_nb_images, self.num_channels, height, width
|
||||||
|
|
||||||
|
def prepare_image_inputs(
|
||||||
|
self,
|
||||||
|
batch_size=None,
|
||||||
|
min_resolution=None,
|
||||||
|
max_resolution=None,
|
||||||
|
num_channels=None,
|
||||||
|
num_images=None,
|
||||||
|
size_divisor=None,
|
||||||
|
equal_resolution=False,
|
||||||
|
numpify=False,
|
||||||
|
torchify=False,
|
||||||
|
):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
|
||||||
|
One can specify whether the images are of the same resolution or not.
|
||||||
|
"""
|
||||||
|
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
|
||||||
|
|
||||||
|
batch_size = batch_size if batch_size is not None else self.batch_size
|
||||||
|
min_resolution = min_resolution if min_resolution is not None else self.min_resolution
|
||||||
|
max_resolution = max_resolution if max_resolution is not None else self.max_resolution
|
||||||
|
num_channels = num_channels if num_channels is not None else self.num_channels
|
||||||
|
num_images = num_images if num_images is not None else self.num_images
|
||||||
|
|
||||||
|
images_list = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
images = []
|
||||||
|
for j in range(num_images):
|
||||||
|
if equal_resolution:
|
||||||
|
width = height = max_resolution
|
||||||
|
else:
|
||||||
|
# To avoid getting image width/height 0
|
||||||
|
if size_divisor is not None:
|
||||||
|
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
|
||||||
|
min_resolution = max(size_divisor, min_resolution)
|
||||||
|
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
|
||||||
|
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
|
||||||
|
images_list.append(images)
|
||||||
|
|
||||||
|
if not numpify and not torchify:
|
||||||
|
# PIL expects the channel dimension as last dimension
|
||||||
|
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
if torchify:
|
||||||
|
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
if numpify:
|
||||||
|
# Numpy images are typically in channels last format
|
||||||
|
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
return images_list
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
|
image_processing_class = SmolVLMImageProcessor if is_vision_available() else None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.image_processor_tester = SmolVLMImageProcessingTester(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_processor_dict(self):
|
||||||
|
return self.image_processor_tester.prepare_image_processor_dict()
|
||||||
|
|
||||||
|
def test_image_processor_properties(self):
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "size"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "resample"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "max_image_size"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||||
|
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
|
||||||
|
|
||||||
|
def test_call_numpy(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random numpy tensors
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||||
|
for sample_images in image_inputs:
|
||||||
|
for image in sample_images:
|
||||||
|
self.assertIsInstance(image, np.ndarray)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||||
|
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||||
|
self.assertEqual(
|
||||||
|
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_call_numpy_4_channels(self):
|
||||||
|
# SmolVLM always processes images as RGB, so it always returns images with 3 channels
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processor_dict = self.image_processor_dict
|
||||||
|
image_processing = self.image_processing_class(**image_processor_dict)
|
||||||
|
# create random numpy tensors
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||||
|
|
||||||
|
for sample_images in image_inputs:
|
||||||
|
for image in sample_images:
|
||||||
|
self.assertIsInstance(image, np.ndarray)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||||
|
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||||
|
self.assertEqual(
|
||||||
|
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_call_pil(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PIL images
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
||||||
|
for images in image_inputs:
|
||||||
|
for image in images:
|
||||||
|
self.assertIsInstance(image, Image.Image)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||||
|
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||||
|
self.assertEqual(
|
||||||
|
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_call_pytorch(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PyTorch tensors
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||||
|
|
||||||
|
for images in image_inputs:
|
||||||
|
for image in images:
|
||||||
|
self.assertIsInstance(image, torch.Tensor)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||||
|
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||||
|
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||||
|
self.assertEqual(
|
||||||
|
tuple(encoded_images.shape),
|
||||||
|
(self.image_processor_tester.batch_size, *expected_output_image_shape),
|
||||||
|
)
|
591
tests/models/smolvlm/test_modeling_smolvlm.py
Normal file
591
tests/models/smolvlm/test_modeling_smolvlm.py
Normal file
@ -0,0 +1,591 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch SmolVLM model."""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import unittest
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
cleanup,
|
||||||
|
require_torch,
|
||||||
|
require_torch_sdpa,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
SmolVLMConfig,
|
||||||
|
SmolVLMForConditionalGeneration,
|
||||||
|
SmolVLMModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLMVisionText2TextModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
is_training=True,
|
||||||
|
batch_size=2,
|
||||||
|
scale_factor=2,
|
||||||
|
num_images=2,
|
||||||
|
vision_config={
|
||||||
|
"image_size": 16,
|
||||||
|
"patch_size": 4,
|
||||||
|
"hidden_size": 32,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"intermediate_size": 32,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"attention_dropout": 0.1,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
},
|
||||||
|
text_config={
|
||||||
|
"vocab_size": 100,
|
||||||
|
"hidden_size": 64,
|
||||||
|
"intermediate_size": 56,
|
||||||
|
"num_hidden_layers": 3,
|
||||||
|
"num_attention_heads": 2,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"max_position_embeddings": 256,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"pad_token_id": 2,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"image_token_id": 57,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"sliding_window": 32,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
},
|
||||||
|
use_cache=False,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
image_token_id=57,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.is_training = is_training
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_images = num_images
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.seq_length = (
|
||||||
|
int(((vision_config["image_size"] // vision_config["patch_size"]) ** 2) / (self.scale_factor**2))
|
||||||
|
* self.num_images
|
||||||
|
)
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.image_token_id = image_token_id
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
# Hack - add properties here so use common tests
|
||||||
|
self.vocab_size = text_config["vocab_size"]
|
||||||
|
self.num_hidden_layers = text_config["num_hidden_layers"]
|
||||||
|
self.num_attention_heads = text_config["num_attention_heads"]
|
||||||
|
self.hidden_size = text_config["hidden_size"]
|
||||||
|
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.text_config = text_config
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return SmolVLMConfig(
|
||||||
|
use_cache=self.use_cache,
|
||||||
|
image_token_id=self.image_token_id,
|
||||||
|
tie_word_embeddings=self.tie_word_embeddings,
|
||||||
|
vision_config=self.vision_config,
|
||||||
|
text_config=self.text_config,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
scale_factor=self.scale_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor(
|
||||||
|
[
|
||||||
|
self.batch_size,
|
||||||
|
self.num_images,
|
||||||
|
3, # SmolVLMImageProcessor always generates RGB pixel values
|
||||||
|
self.vision_config["image_size"],
|
||||||
|
self.vision_config["image_size"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 1
|
||||||
|
|
||||||
|
# For simplicity just set the last n tokens to the image token
|
||||||
|
n_image_tokens_per_batch = self.seq_length
|
||||||
|
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
inputs_dict = {
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class SmolVLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Model tester for `SmolVLM`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_model_classes = (SmolVLMModel,) if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_torchscript = False
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = True
|
||||||
|
test_head_masking = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = SmolVLMVisionText2TextModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(
|
||||||
|
self, config_class=SmolVLMConfig, has_text_modality=False, common_properties=["image_token_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||||
|
def test_inputs_embeds():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||||
|
def test_inputs_embeds_matches_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Model does not support padding right")
|
||||||
|
def test_flash_attn_2_inference_padding_right(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in SmolVLM models")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in SmolVLM models")
|
||||||
|
def test_sdpa_can_dispatch_on_flash(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
if self.model_tester.is_training is False:
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
# Retrieve the embeddings and clone theme
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||||
|
cloned_embeddings = model_embed.weight.clone()
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
|
||||||
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
|
||||||
|
n_images = self.model_tester.num_images * self.model_tester.seq_length
|
||||||
|
model.image_token_id = model_vocab_size - 15 - 1
|
||||||
|
inputs_dict["input_ids"][:, -n_images:] = model.image_token_id
|
||||||
|
|
||||||
|
# make sure that decoder_input_ids are resized as well
|
||||||
|
if "decoder_input_ids" in inputs_dict:
|
||||||
|
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||||
|
self.assertTrue(model.config.text_config.vocab_size + 10, model_vocab_size)
|
||||||
|
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||||
|
|
||||||
|
self.assertTrue(model_embed.weight.shape[0], model.config.text_config.vocab_size)
|
||||||
|
self.assertTrue(model.config.text_config.vocab_size, model.vocab_size)
|
||||||
|
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||||
|
|
||||||
|
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
|
||||||
|
target_dimension = 128
|
||||||
|
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0], target_dimension)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
||||||
|
):
|
||||||
|
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
|
||||||
|
|
||||||
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
original_config.tie_word_embeddings = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
|
||||||
|
# if no output embeddings -> leave test
|
||||||
|
if model.get_output_embeddings() is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model.resize_token_embeddings(model_vocab_size - 15)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
|
||||||
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
|
||||||
|
n_images = self.model_tester.num_images * self.model_tester.seq_length
|
||||||
|
model.image_token_id = model_vocab_size - 15 - 1
|
||||||
|
inputs_dict["input_ids"][:, -n_images:] = model.image_token_id
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Model tester for `SmolVLMForConditionalGeneration`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_model_classes = (SmolVLMForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (SmolVLMForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = {"image-text-to-text": SmolVLMForConditionalGeneration} if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = True
|
||||||
|
test_head_masking = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = SmolVLMVisionText2TextModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=SmolVLMConfig, has_text_modality=False)
|
||||||
|
|
||||||
|
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||||
|
def test_inputs_embeds():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Model does not support padding right")
|
||||||
|
def test_flash_attn_2_inference_padding_right(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
|
||||||
|
def test_contrastive_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
|
||||||
|
)
|
||||||
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
|
||||||
|
def test_flash_attn_2_fp32_ln(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
|
)
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
|
)
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Unsupported")
|
||||||
|
def test_generate_from_inputs_embeds_0_greedy(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Unsupported")
|
||||||
|
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Unsupported")
|
||||||
|
def test_generate_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in SmolVLM models")
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Compile not yet supported in SmolVLM models")
|
||||||
|
def test_sdpa_can_dispatch_on_flash(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
@require_torch_sdpa
|
||||||
|
@slow
|
||||||
|
@unittest.skip(
|
||||||
|
reason="SmolVLM doesn't support SDPA for all backbones, vision backbones has only eager/FA2 attention"
|
||||||
|
)
|
||||||
|
def test_eager_matches_sdpa_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
|
||||||
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
# Retrieve the embeddings and clone theme
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||||
|
cloned_embeddings = model_embed.weight.clone()
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
|
||||||
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
|
||||||
|
n_images = self.model_tester.num_images * self.model_tester.seq_length
|
||||||
|
model.model.image_token_id = model_vocab_size - 15 - 1
|
||||||
|
inputs_dict["input_ids"][:, -n_images:] = model.model.image_token_id
|
||||||
|
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||||
|
self.assertTrue(model.config.text_config.vocab_size + 10, model_vocab_size)
|
||||||
|
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||||
|
|
||||||
|
self.assertTrue(model_embed.weight.shape[0], model.config.text_config.vocab_size)
|
||||||
|
self.assertTrue(model.config.text_config.vocab_size, model.vocab_size)
|
||||||
|
|
||||||
|
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||||
|
|
||||||
|
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
|
||||||
|
target_dimension = 128
|
||||||
|
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0], target_dimension)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
||||||
|
):
|
||||||
|
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
|
||||||
|
|
||||||
|
# We need to override as we need to prepare such that the image token is the last token
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
original_config.tie_word_embeddings = False
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model_vocab_size = config.text_config.vocab_size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model.resize_token_embeddings(model_vocab_size - 15)
|
||||||
|
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
|
||||||
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
|
||||||
|
n_images = self.model_tester.num_images * self.model_tester.seq_length
|
||||||
|
model.model.image_token_id = model_vocab_size - 15 - 1
|
||||||
|
inputs_dict["input_ids"][:, -n_images:] = model.model.image_token_id
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
|
||||||
|
self.image1 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.image2 = Image.open(
|
||||||
|
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
|
||||||
|
)
|
||||||
|
self.image3 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
# TODO (Orr?) this is a dummy test to check if the model generates things that make sense.
|
||||||
|
# Needs to be expanded to a tiny video
|
||||||
|
def test_integration_test(self):
|
||||||
|
model = SmolVLMForConditionalGeneration.from_pretrained(
|
||||||
|
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create inputs
|
||||||
|
text = "<image>In this image, we see"
|
||||||
|
images = self.image1
|
||||||
|
inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
|
||||||
|
inputs.to(device=torch_device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
generated_ids = model.generate(**inputs, max_new_tokens=9)
|
||||||
|
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
expected_generated_text = "\n\n\n\nIn this image, we see a view of the Statue of Liberty and the"
|
||||||
|
self.assertEqual(generated_texts[0], expected_generated_text)
|
655
tests/models/smolvlm/test_processor_smolvlm.py
Normal file
655
tests/models/smolvlm/test_processor_smolvlm.py
Normal file
@ -0,0 +1,655 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from transformers import SmolVLMProcessor
|
||||||
|
from transformers.models.auto.processing_auto import AutoProcessor
|
||||||
|
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||||
|
from transformers.utils import is_vision_available
|
||||||
|
|
||||||
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
|
processor_class = SmolVLMProcessor
|
||||||
|
videos_input_name = "pixel_values"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.tmpdirname = tempfile.mkdtemp()
|
||||||
|
processor = SmolVLMProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct", image_seq_len=2)
|
||||||
|
processor.save_pretrained(cls.tmpdirname)
|
||||||
|
cls.image1 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cls.image2 = Image.open(
|
||||||
|
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
|
||||||
|
)
|
||||||
|
cls.image3 = Image.open(
|
||||||
|
BytesIO(
|
||||||
|
requests.get(
|
||||||
|
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cls.bos_token = processor.tokenizer.bos_token
|
||||||
|
cls.image_token = processor.image_token
|
||||||
|
cls.fake_image_token = processor.fake_image_token
|
||||||
|
cls.global_img_token = processor.global_image_token
|
||||||
|
|
||||||
|
cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token)
|
||||||
|
cls.image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.image_token)
|
||||||
|
cls.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.fake_image_token)
|
||||||
|
cls.global_img_tokens_id = processor.tokenizer(cls.global_img_token, add_special_tokens=False)["input_ids"]
|
||||||
|
cls.padding_token_id = processor.tokenizer.pad_token_id
|
||||||
|
cls.image_seq_len = processor.image_seq_len
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def get_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def prepare_processor_dict(self):
|
||||||
|
return {
|
||||||
|
"image_seq_len": self.image_seq_len,
|
||||||
|
"chat_template": "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_split_image_expected_tokens(self, processor, image_rows, image_cols):
|
||||||
|
text_split_images = []
|
||||||
|
for n_h in range(image_rows):
|
||||||
|
for n_w in range(image_cols):
|
||||||
|
text_split_images += (
|
||||||
|
[self.fake_image_token_id]
|
||||||
|
+ processor.tokenizer(f"<row_{n_h + 1}_col_{n_w + 1}>", add_special_tokens=False)["input_ids"]
|
||||||
|
+ [self.image_token_id] * self.image_seq_len
|
||||||
|
)
|
||||||
|
text_split_images += processor.tokenizer("\n", add_special_tokens=False)["input_ids"]
|
||||||
|
text_split_images = text_split_images[:-1] # remove last newline
|
||||||
|
# add double newline, as it gets its own token
|
||||||
|
text_split_images += processor.tokenizer("\n\n", add_special_tokens=False)["input_ids"]
|
||||||
|
text_split_images += (
|
||||||
|
[self.fake_image_token_id]
|
||||||
|
+ self.global_img_tokens_id
|
||||||
|
+ [self.image_token_id] * self.image_seq_len
|
||||||
|
+ [self.fake_image_token_id]
|
||||||
|
)
|
||||||
|
return text_split_images
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
shutil.rmtree(cls.tmpdirname)
|
||||||
|
|
||||||
|
def test_process_interleaved_images_prompts_no_image_splitting(self):
|
||||||
|
processor_components = self.prepare_components()
|
||||||
|
processor_components["tokenizer"] = self.get_component("tokenizer", padding_side="left")
|
||||||
|
processor_components["image_processor"] = self.get_component("image_processor", do_image_splitting=False)
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
|
processor = self.processor_class(**processor_components, **processor_kwargs)
|
||||||
|
|
||||||
|
# Test that a single image is processed correctly
|
||||||
|
inputs = processor(images=self.image1)
|
||||||
|
image1_expected_size = (512, 512)
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 1, 3, *image1_expected_size))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 1, *image1_expected_size))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test a single sample with image and text
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str = "In this image, we see"
|
||||||
|
text = image_str + text_str
|
||||||
|
inputs = processor(text=text, images=self.image1)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
|
||||||
|
expected_input_ids = [[self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
|
||||||
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
|
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 1, 3, *image1_expected_size))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 1, *image1_expected_size))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test that batch is correctly processed
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str_1 = "In this image, we see"
|
||||||
|
text_str_2 = "In this image, we see"
|
||||||
|
|
||||||
|
text = [
|
||||||
|
image_str + text_str_1,
|
||||||
|
image_str + image_str + text_str_2,
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
|
||||||
|
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
|
||||||
|
image_tokens = [self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
|
||||||
|
expected_input_ids_1 = image_tokens + tokenized_sentence_1["input_ids"]
|
||||||
|
expected_input_ids_2 = 2 * image_tokens + tokenized_sentence_2["input_ids"]
|
||||||
|
# Pad the first input to match the second input
|
||||||
|
pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
|
||||||
|
padded_expected_input_ids_1 = [self.padding_token_id] * pad_len + expected_input_ids_1
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["input_ids"], [padded_expected_input_ids_1, expected_input_ids_2]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["attention_mask"],
|
||||||
|
[[0] * pad_len + [1] * len(expected_input_ids_1), [1] * len(expected_input_ids_2)]
|
||||||
|
)
|
||||||
|
self.assertEqual(np.array(inputs['pixel_values']).shape, (2, 2, 3, 512, 512))
|
||||||
|
self.assertEqual(np.array(inputs['pixel_attention_mask']).shape, (2, 2, 512, 512))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_process_interleaved_images_prompts_image_splitting(self):
|
||||||
|
processor_components = self.prepare_components()
|
||||||
|
processor_components["tokenizer"] = self.get_component("tokenizer", padding_side="left")
|
||||||
|
processor_components["image_processor"] = self.get_component("image_processor", do_image_splitting=True)
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
|
processor = self.processor_class(**processor_components, **processor_kwargs)
|
||||||
|
|
||||||
|
# Test that a single image is processed correctly
|
||||||
|
inputs = processor(images=self.image1)
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 13, 3, 512, 512))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 13, 512, 512))
|
||||||
|
# fmt: on
|
||||||
|
self.maxDiff = None
|
||||||
|
|
||||||
|
# Test a single sample with image and text
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str = "In this image, we see"
|
||||||
|
text = image_str + text_str
|
||||||
|
inputs = processor(text=text, images=self.image1)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
|
||||||
|
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
|
||||||
|
expected_input_ids_1 = [split_image1_tokens + tokenized_sentence["input_ids"]]
|
||||||
|
self.assertEqual(inputs["input_ids"], expected_input_ids_1)
|
||||||
|
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids_1[0])])
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 13, 3, 512, 512))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 13, 512, 512))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Test that batch is correctly processed
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str_1 = "In this image, we see"
|
||||||
|
text_str_2 = "bla, bla"
|
||||||
|
|
||||||
|
text = [
|
||||||
|
image_str + text_str_1,
|
||||||
|
text_str_2 + image_str + image_str,
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
|
||||||
|
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
|
||||||
|
|
||||||
|
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
|
||||||
|
split_image2_tokens = self.get_split_image_expected_tokens(processor, 4, 4)
|
||||||
|
split_image3_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
|
||||||
|
expected_input_ids_1 = split_image1_tokens + tokenized_sentence_1["input_ids"]
|
||||||
|
expected_input_ids_2 = tokenized_sentence_2["input_ids"] + split_image2_tokens + split_image3_tokens
|
||||||
|
# Pad the first input to match the second input
|
||||||
|
pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
|
||||||
|
padded_expected_input_ids_1 = [self.padding_token_id] * pad_len + expected_input_ids_1
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["input_ids"], [padded_expected_input_ids_1, expected_input_ids_2]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["attention_mask"],
|
||||||
|
[[0] * pad_len + [1] * len(expected_input_ids_1), [1] * len(expected_input_ids_2)]
|
||||||
|
)
|
||||||
|
self.assertEqual(np.array(inputs['pixel_values']).shape, (2, 30, 3, 512, 512))
|
||||||
|
self.assertEqual(np.array(inputs['pixel_attention_mask']).shape, (2, 30, 512, 512))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_add_special_tokens_processor(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str = "In this image, we see"
|
||||||
|
text = text_str + image_str
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
inputs = processor(text=text, images=self.image1, add_special_tokens=False)
|
||||||
|
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
|
||||||
|
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
|
||||||
|
expected_input_ids = [tokenized_sentence["input_ids"] + split_image1_tokens]
|
||||||
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=self.image1)
|
||||||
|
expected_input_ids = [tokenized_sentence["input_ids"] + split_image1_tokens]
|
||||||
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
@unittest.skip(reason="from @molbap @zucchini-nlp, passing non-nested images is error-prone and not recommended")
|
||||||
|
def test_non_nested_images_with_batched_text(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
processor.image_processor.do_image_splitting = False
|
||||||
|
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str_1 = "In this image, we see"
|
||||||
|
text_str_2 = "In this image, we see"
|
||||||
|
|
||||||
|
text = [
|
||||||
|
image_str + text_str_1,
|
||||||
|
image_str + image_str + text_str_2,
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 2, 3, 512, 512))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (2, 2, 512, 512))
|
||||||
|
|
||||||
|
# Copied from tests.models.idefics2.test_processor_idefics2.Idefics2ProcessorTest.test_process_interleaved_images_prompts_image_error
|
||||||
|
def test_process_interleaved_images_prompts_image_error(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.<image>",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2, self.image3]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
def test_apply_chat_template(self):
|
||||||
|
# Message contains content which a mix of lists with images and image urls and string
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What do these images show?"},
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "image"},
|
||||||
|
"What do these images show?",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "And who is that?"}]},
|
||||||
|
]
|
||||||
|
processor = self.get_processor()
|
||||||
|
# Make short sequence length to test that the fake tokens are added correctly
|
||||||
|
rendered = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
expected_rendered = (
|
||||||
|
"<|im_start|>User: What do these images show?<image><image><end_of_utterance>\n"
|
||||||
|
"Assistant: The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.<end_of_utterance>\n"
|
||||||
|
"User: And who is that?<end_of_utterance>\n"
|
||||||
|
"Assistant:"
|
||||||
|
)
|
||||||
|
self.assertEqual(rendered, expected_rendered)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Broken from common. Fixing TODO @zucchini-nlp @molbap")
|
||||||
|
def test_chat_template_video_special_processing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@require_av
|
||||||
|
def test_chat_template_video(self):
|
||||||
|
# overriden because SmolVLM has special preprocessing for videos
|
||||||
|
processor = self.get_processor()
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "What is shown in this video?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
num_frames = 3
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
num_frames=num_frames,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||||
|
# SmolVLM doesn't sample `num_frames` exactly, by uses other sampling method
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 10)
|
||||||
|
|
||||||
|
# Load with `video_fps` arg
|
||||||
|
video_fps = 1
|
||||||
|
out_dict_with_video = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
video_fps=video_fps,
|
||||||
|
)
|
||||||
|
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||||
|
# SmolVLM doesn't sample 1 frame per second exactly, by uses other sampling method
|
||||||
|
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)
|
||||||
|
|
||||||
|
# NOTE: the last assert checks are removed
|
||||||
|
# Loading video as a list of frames (i.e. images) is not supported in SmolVLM
|
||||||
|
|
||||||
|
# Override as SmolVLMProcessor needs image tokens in prompts
|
||||||
|
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||||
|
if batch_size is None:
|
||||||
|
return "lower newer <image>"
|
||||||
|
|
||||||
|
if batch_size < 1:
|
||||||
|
raise ValueError("batch_size must be greater than 0")
|
||||||
|
|
||||||
|
if batch_size == 1:
|
||||||
|
return ["lower newer <image>"]
|
||||||
|
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
|
||||||
|
batch_size - 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override tests as inputs_ids padded dimension is the second one but not the last one
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer", max_length=30)
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
input_str = self.prepare_text_inputs()
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=30)
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 30)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_structured_kwargs_nested(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer")
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
|
||||||
|
input_str = self.prepare_text_inputs()
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
# Define the kwargs for each modality
|
||||||
|
inputs = processor(
|
||||||
|
text=input_str,
|
||||||
|
images=image_input,
|
||||||
|
common_kwargs={"return_tensors": "pt"},
|
||||||
|
images_kwargs={"max_image_size": {"longest_edge": 32}},
|
||||||
|
text_kwargs={"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
|
||||||
|
)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[3], 32)
|
||||||
|
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 120)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_structured_kwargs_nested_from_dict(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer")
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
input_str = self.prepare_text_inputs()
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
# Define the kwargs for each modality
|
||||||
|
all_kwargs = {
|
||||||
|
"common_kwargs": {"return_tensors": "pt"},
|
||||||
|
"images_kwargs": {"max_image_size": {"longest_edge": 32}},
|
||||||
|
"text_kwargs": {"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[3], 32)
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 120)
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_defaults_preserved_by_kwargs(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer", max_length=30)
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
input_str = self.prepare_text_inputs()
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 30)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_unstructured_kwargs_batched(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer")
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
|
||||||
|
input_str = self.prepare_text_inputs(batch_size=2)
|
||||||
|
image_input = self.prepare_image_inputs(batch_size=2)
|
||||||
|
image_input = [[image_input[0]], [image_input[1]]]
|
||||||
|
inputs = processor(
|
||||||
|
text=input_str,
|
||||||
|
images=image_input,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
max_length=76,
|
||||||
|
truncation=True,
|
||||||
|
max_image_size={"longest_edge": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[2], 3)
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[3], 30)
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_unstructured_kwargs(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer")
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
|
||||||
|
input_str = self.prepare_text_inputs()
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
inputs = processor(
|
||||||
|
text=input_str,
|
||||||
|
images=image_input,
|
||||||
|
return_tensors="pt",
|
||||||
|
max_image_size={"longest_edge": 32},
|
||||||
|
padding="max_length",
|
||||||
|
max_length=120,
|
||||||
|
truncation="longest_first",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[3], 32)
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 120)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_text_only_inference(self):
|
||||||
|
"""Test that the processor works correctly with text-only input."""
|
||||||
|
processor_components = self.prepare_components()
|
||||||
|
processor_components["tokenizer"] = self.get_component("tokenizer", padding_side="left")
|
||||||
|
processor_kwargs = self.prepare_processor_dict()
|
||||||
|
|
||||||
|
processor = self.processor_class(**processor_components, **processor_kwargs)
|
||||||
|
|
||||||
|
text = "This is a simple text without images."
|
||||||
|
inputs = processor(text=text)
|
||||||
|
|
||||||
|
tokenized_sentence = processor.tokenizer(text, add_special_tokens=False)
|
||||||
|
expected_input_ids = [tokenized_sentence["input_ids"]]
|
||||||
|
|
||||||
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
|
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
|
||||||
|
self.assertTrue("pixel_values" not in inputs)
|
||||||
|
self.assertTrue("pixel_attention_mask" not in inputs)
|
||||||
|
|
||||||
|
# Test batch of texts without image tokens
|
||||||
|
texts = ["First text.", "Second piece of text."]
|
||||||
|
batch_inputs = processor(text=texts, padding=True)
|
||||||
|
|
||||||
|
tokenized_1 = processor.tokenizer(texts[0], add_special_tokens=False)
|
||||||
|
tokenized_2 = processor.tokenizer(texts[1], add_special_tokens=False)
|
||||||
|
|
||||||
|
expected_1 = tokenized_1["input_ids"]
|
||||||
|
expected_2 = tokenized_2["input_ids"]
|
||||||
|
|
||||||
|
# Pad the shorter sequence
|
||||||
|
pad_len = len(expected_2) - len(expected_1)
|
||||||
|
if pad_len > 0:
|
||||||
|
padded_expected_1 = [self.padding_token_id] * pad_len + expected_1
|
||||||
|
expected_attention_1 = [0] * pad_len + [1] * len(expected_1)
|
||||||
|
self.assertEqual(batch_inputs["input_ids"], [padded_expected_1, expected_2])
|
||||||
|
self.assertEqual(batch_inputs["attention_mask"], [expected_attention_1, [1] * len(expected_2)])
|
||||||
|
else:
|
||||||
|
pad_len = -pad_len
|
||||||
|
padded_expected_2 = [self.padding_token_id] * pad_len + expected_2
|
||||||
|
expected_attention_2 = [0] * pad_len + [1] * len(expected_2)
|
||||||
|
self.assertEqual(batch_inputs["input_ids"], [expected_1, padded_expected_2])
|
||||||
|
self.assertEqual(batch_inputs["attention_mask"], [[1] * len(expected_1), expected_attention_2])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_missing_images_error(self):
|
||||||
|
"""Test that appropriate error is raised when images are referenced but not provided."""
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
# Test single text with image token but no image
|
||||||
|
text = "Let me show you this image: <image> What do you think?"
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
processor(text=text)
|
||||||
|
self.assertTrue("tokens in the text but no images/videos were passed" in str(context.exception))
|
||||||
|
|
||||||
|
# Test batch with image tokens but no images
|
||||||
|
texts = [
|
||||||
|
"First text with <image> token.",
|
||||||
|
"Second text <image> with token.",
|
||||||
|
]
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
processor(text=texts)
|
||||||
|
self.assertTrue("tokens in the text but no images/videos were passed" in str(context.exception))
|
||||||
|
|
||||||
|
# Test with None as Images
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
processor(text=text, images=None)
|
||||||
|
self.assertTrue("tokens in the text but no images/videos were passed" in str(context.exception))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
processor(text=texts, images=None)
|
||||||
|
self.assertTrue("tokens in the text but no images/videos were passed" in str(context.exception))
|
@ -57,6 +57,7 @@ from transformers.models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
|
||||||
@ -262,6 +263,7 @@ class ModelTesterMixin:
|
|||||||
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
|
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES),
|
*get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES),
|
||||||
|
*get_values(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
|
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
|
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
|
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
|
||||||
|
@ -86,6 +86,7 @@ PRIVATE_MODELS = [
|
|||||||
"Idefics2PerceiverResampler",
|
"Idefics2PerceiverResampler",
|
||||||
"Idefics2VisionTransformer",
|
"Idefics2VisionTransformer",
|
||||||
"Idefics3VisionTransformer",
|
"Idefics3VisionTransformer",
|
||||||
|
"SmolVLMVisionTransformer",
|
||||||
"AriaTextForCausalLM",
|
"AriaTextForCausalLM",
|
||||||
"AriaTextModel",
|
"AriaTextModel",
|
||||||
]
|
]
|
||||||
|
@ -180,6 +180,7 @@ MODEL_NAMES_TO_IGNORE = [
|
|||||||
"CLIPVisionModel",
|
"CLIPVisionModel",
|
||||||
"Qwen2AudioEncoder",
|
"Qwen2AudioEncoder",
|
||||||
"SiglipVisionModel",
|
"SiglipVisionModel",
|
||||||
|
"SmolVLMVisionTransformer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user