* Fix converter

* [Broken] Adds Gemma 3 to Hugging Face Transformers

* Consolidating Config and Processor params across impls

* Sorting out configuration parameters. Adds qk_norm before RoPE. Still not sure if RoPE is right.

* Additional plumbing for CausalLM and ConditionalGeneration variants

* incomplete draft of Orbax conversion script

* More complete checkpoint conversion

* Supporting Gemma 3 1B checkpoints

* Updating RoPE for multiple frequencies

* Adjustments to rotary embedder

* Proof of life for text-only operation

* Updating the conversion script to handle multimodal projection weights

* Fixing tet-only conversions

* Cleaner conversion script with multimodal support and a simpler processor

* Additional refatcors to the Gemma3Processor

* Simplified Processor to work over text representations

* Updated conversion script to join text and vision embeddings at converion time

* Logging for debugging

* Update src/transformers/models/gemma2/modeling_gemma2.py

Co-authored-by: Joshua Lochner <admin@xenova.com>

* Removed extraneous Config params

* Switching to fast tokenizer for checkpoint conversions

* isolating siglip for performance tetsing

* Minor changes for debugging tests against baselines

* Adding average pooling for soft tokens

* Updating processor code to enable simpler embedding interleaving for arbitrary number of images in prompts

* Updating conversion script for ShieldGemma 2 conversion compatibility

* Allow disable_compile to be provided as a kwarg

* Refresh from modular

* Updated conversion script and corrected sliding window

* Fix type mismatch in cache_position (#4)

* Fix dtype (#5)

* Fix type mismatch in cache_position

* Actually fix in the modular file

Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>

---------

Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>

* fixes for embedding table overflow and missing image_soft_token_mask from Gemma3Processor

* Adding 2D pooling for image embeddings

* Revert "Adding 2D pooling for image embeddings"

This reverts commit 65350cf531.

* Gemma3 average pooling changed from 1D to 2D

* Major refactor to Gemma3MultimodalInputProjection

* Updating Gemm 3 Auto* registrations

* Add option to save Gemma 3 chat template with tokenizer during weights conversion

* Removing unused imports

* Moving out-of-vocab handling from Gemma3Processor to Gemma3ForConditionalGeneration

* Removing duplicate config property

* Removing final logit softcapping and 1-indexing of position ids

* Fixing image processor config and none --> None typo

* Fixing sliding window size for 1B

* Updating image_mean and image_std in Image Processor

* Attention masking changed to lower triangular

* Moving image special tokens to conversion script

* Mirror image processor defaults from conversion script into Gemma3ProcessorKwargs

* Remove special token variables from symbol space

* Moving image soft token mask computation from Gemma3Processor to Gemma3ForConditionalGeneration

* tie lm_head and embedding weights

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Correct tied weights in Gemma3CausalLM

* iterative bidirectional attention

* resolving merge conflicts

* Reverting to Gemma 2 HybridCache with sldiing window support and a sliding_window_pattern of 6

* Correcting RoPE scaling

* clean up first pass, dummy model geenration works

* final clean up before fixing tests

* causal lm test works, so fine

* Fix conversion

* Update src/transformers/models/gemma3/processing_gemma3.py

* model tests are happy

* processor tests are happy

* image processing tests added

* fixup

* Fix pre-processing in conversion

* Inputs merging

* Do not normalize vision embeddings

* Apply Ryan's (and team) changes to attention

* token type ids + mask

* template

* move embed scale, add rope scale, fix tests

* Add chat template to tokenizer

* Use prefix for causal model loading

* use existing code for sliding mask from gemma2

* self.embed_tokens already normalizes

* Correcting Gemma3TextConfig parameters in conversion script

* typo, modular overwrites my fixes

* enable device map for text model

* Conversion updates

* ultra nit: no einsums

* update image token

* copy deepcopy config + some docs

* add some test, still WIP

* Refactoring --include_chat_tempalte logic in converter

* Update src/transformers/models/gemma3/modular_gemma3.py

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* Add eos tokens for instruct models

* dump so i can work on dgx

* Removing add_bos by default

* dump

* add fast im proc

* docs for PaS + fixup

* another fixup

* one more fixup

* fix tests

* Inverting prior BOS change

* ultra nit

* Reverting to Tokenizer saved with add_bos_token=True and chat template starting with BOS

* resize embeds, remove sqrt, add slow test outputs

* FA2 but quality is meh

* nit

* skip FA2, no idea what happened

* last bit for green CI

* please, green CI for docs

* T_T

* Fix for Gemma3 logits

* Support both options for system prompt

* Update src/transformers/models/gemma3/image_processing_gemma3_fast.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/gemma3.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/gemma3.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/gemma3.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/gemma3.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/model_doc/gemma3.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Docs updates now that assets are live

* Style fixes

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Lysandre <hi@lysand.re>
This commit is contained in:
Ryan Mullins 2025-03-12 04:06:17 -04:00 committed by GitHub
parent 81aa9b2e07
commit 50d3530aa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 5469 additions and 116 deletions

View File

@ -0,0 +1,203 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Gemma3
## Overview
The Gemma 3 model was proposed in the [Gemma 3 Techncial Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed by a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. It cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed certain aspect ratio. For images that exceed the given aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where each sixth layer is a full causal attention layer.
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq).
## Usage tips
- For image+text and image-only inputs use `Gemma3ForConditionalGeneration`.
- For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
### Image cropping for high resolution images
The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image to improve the quality in DocVQA or similar tasks requiring higher resolution images.
Pan and scan is an inference time optimization to handle images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc.
```python
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user", "content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
do_pan_and_scan=True,
).to(model.device)
```
## Usage Example
### Single-image Inference
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user", "content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
```
### Multi-image Inference
```python
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user", "content": [
{"type": "image", "url": url_cow},
{"type": "image", "url": url_stop},
{"type": "text", "text": "Are these two images identical?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
```
### Text-only inference
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
```python
from transformers import AutoTokenizer, Gemma3ForCausalLM
model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
outputs = model.generate(**input_ids, max_new_tokens=100)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(text)
```
## Gemma3ImageProcessor
[[autodoc]] Gemma3ImageProcessor
## Gemma3ImageProcessorFast
[[autodoc]] Gemma3ImageProcessorFast
## Gemma3Processor
[[autodoc]] Gemma3Processor
## Gemma3TextConfig
[[autodoc]] Gemma3TextConfig
## Gemma3Config
[[autodoc]] Gemma3Config
## Gemma3TextModel
[[autodoc]] Gemma3TextModel
- forward
## Gemma3ForCausalLM
[[autodoc]] Gemma3ForCausalLM
- forward
## Gemma3ForConditionalGeneration
[[autodoc]] Gemma3ForConditionalGeneration
- forward

View File

@ -474,6 +474,7 @@ _import_structure = {
"models.fuyu": ["FuyuConfig"],
"models.gemma": ["GemmaConfig"],
"models.gemma2": ["Gemma2Config"],
"models.gemma3": ["Gemma3Config", "Gemma3Processor", "Gemma3TextConfig"],
"models.git": [
"GitConfig",
"GitProcessor",
@ -1259,6 +1260,7 @@ else:
_import_structure["models.emu3"].append("Emu3ImageProcessor")
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
_import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
_import_structure["models.gemma3"].append("Gemma3ImageProcessor")
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
_import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
_import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
@ -1332,6 +1334,7 @@ else:
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
@ -2452,6 +2455,14 @@ else:
"Gemma2PreTrainedModel",
]
)
_import_structure["models.gemma3"].extend(
[
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3PreTrainedModel",
"Gemma3TextModel",
]
)
_import_structure["models.git"].extend(
[
"GitForCausalLM",
@ -2554,6 +2565,7 @@ else:
"GraniteMoePreTrainedModel",
]
)
_import_structure["models.granitemoeshared"].extend(
[
"GraniteMoeSharedForCausalLM",
@ -2561,7 +2573,6 @@ else:
"GraniteMoeSharedPreTrainedModel",
]
)
_import_structure["models.grounding_dino"].extend(
[
"GroundingDinoForObjectDetection",
@ -5629,6 +5640,7 @@ if TYPE_CHECKING:
from .models.fuyu import FuyuConfig
from .models.gemma import GemmaConfig
from .models.gemma2 import Gemma2Config
from .models.gemma3 import Gemma3Config, Gemma3Processor, Gemma3TextConfig
from .models.git import (
GitConfig,
GitProcessor,
@ -6450,6 +6462,7 @@ if TYPE_CHECKING:
FlavaProcessor,
)
from .models.fuyu import FuyuImageProcessor, FuyuProcessor
from .models.gemma3 import Gemma3ImageProcessor
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
from .models.got_ocr2 import GotOcr2ImageProcessor
from .models.grounding_dino import GroundingDinoImageProcessor
@ -6535,6 +6548,7 @@ if TYPE_CHECKING:
from .models.deit import DeiTImageProcessorFast
from .models.depth_pro import DepthProImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.gemma3 import Gemma3ImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
@ -7461,6 +7475,12 @@ if TYPE_CHECKING:
Gemma2Model,
Gemma2PreTrainedModel,
)
from .models.gemma3 import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3PreTrainedModel,
Gemma3TextModel,
)
from .models.git import (
GitForCausalLM,
GitModel,

View File

@ -113,10 +113,10 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):
sp = self.sp
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
# there is a missing token in the vocab. We have to do this to support merges
# If "\t" is missing in the vocab, we have to do this to support merges
# "<0x09>" is the bytefallback for `\t`
vocab["\t"] = vocab.get("<0x09>")
if "\t" not in vocab:
vocab["\t"] = vocab.get("<0x09>")
merges = generate_merges(vocab, vocab_scores)
return vocab, merges
@ -1296,12 +1296,14 @@ class GemmaConverter(SpmConverter):
(self.original_tokenizer.eos_token, 0.0),
(self.original_tokenizer.bos_token, 0.0),
]
for piece in proto.pieces[3:]:
if piece.piece == "<0x09>":
vocab += [("\t", piece.score)]
else:
vocab += [(piece.piece, piece.score)]
# vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
# Older gemma tokenizers had a missing tab token, so we fix that here
if not any(x[0] == "\t" for x in vocab):
override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None)
if override_index is not None:
vocab[override_index] = ("\t", 0.0)
return vocab
def pre_tokenizer(self, replacement, add_prefix_space):

View File

@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
for serialized_param_name, empty_param in state_dict.items():
if serialized_param_name not in expected_keys:
continue
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
# we need to use serialized_param_name as file pointer is untouched
if shard_file.endswith(".safetensors"):
param = file_pointer.get_slice(serialized_param_name)

View File

@ -106,6 +106,7 @@ from . import (
fuyu,
gemma,
gemma2,
gemma3,
git,
glm,
glpn,

View File

@ -124,6 +124,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("gemma3", "Gemma3Config"),
("gemma3_text", "Gemma3TextConfig"),
("git", "GitConfig"),
("glm", "GlmConfig"),
("glpn", "GLPNConfig"),
@ -459,6 +461,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"),
("git", "GIT"),
("glm", "GLM"),
("glpn", "GLPN"),
@ -748,6 +752,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("qwen2_audio_encoder", "qwen2_audio"),
("clip_text_model", "clip"),
("aria_text", "aria"),
("gemma3_text", "gemma3"),
("idefics3_vision", "idefics3"),
("siglip_vision_model", "siglip"),
("smolvlm_vision", "smolvlm"),

View File

@ -86,6 +86,7 @@ else:
("flava", ("FlavaImageProcessor",)),
("focalnet", ("BitImageProcessor",)),
("fuyu", ("FuyuImageProcessor",)),
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)),
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),

View File

@ -118,6 +118,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("gemma3_text", "Gemma3TextModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glpn", "GLPNModel"),
@ -338,6 +339,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("fnet", "FNetForPreTraining"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForPreTraining"),
("gemma3", "Gemma3ForConditionalGeneration"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("gpt_bigcode", "GPTBigCodeForCausalLM"),
@ -518,6 +520,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("gemma3", "Gemma3ForCausalLM"),
("gemma3_text", "Gemma3ForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
@ -824,6 +828,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("chameleon", "ChameleonForConditionalGeneration"),
("emu3", "Emu3ForConditionalGeneration"),
("fuyu", "FuyuForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"),
("git", "GitForCausalLM"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
("idefics", "IdeficsForVisionText2Text"),

View File

@ -63,6 +63,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("gemma3", "Gemma3Processor"),
("git", "GitProcessor"),
("got_ocr2", "GotOcr2Processor"),
("grounding-dino", "GroundingDinoProcessor"),

View File

@ -215,6 +215,13 @@ else:
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"gemma3",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),

View File

@ -41,7 +41,6 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -936,42 +935,23 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
if (
isinstance(past_key_values, HybridCache)
@ -994,19 +974,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["attention_mask"] = attention_mask
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs

View File

@ -29,7 +29,7 @@ from ...modeling_outputs import (
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import is_torchdynamo_compiling, logging
from ...utils import logging
from ..gemma.modeling_gemma import (
GemmaAttention,
GemmaForCausalLM,
@ -686,42 +686,23 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
if (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
if (
isinstance(past_key_values, HybridCache)
@ -744,19 +725,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["attention_mask"] = attention_mask
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs

View File

@ -0,0 +1,30 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_gemma3 import *
from .image_processing_gemma3 import *
from .image_processing_gemma3_fast import *
from .modeling_gemma3 import *
from .processing_gemma3 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,330 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
from ..siglip import SiglipVisionConfig
logger = logging.get_logger(__name__)
class Gemma3TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma3Text-7B.
e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262208):
Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma3TextModel`]
hidden_size (`int`, *optional*, defaults to 2304):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 9216):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 26):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
Scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
```python
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
>>> configuration = Gemma3TextConfig()
>>> # Initializing a model from the gemma3_text-7b style configuration
>>> model = Gemma3TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
"""
model_type = "gemma3_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=262_208,
hidden_size=2304,
intermediate_size=9216,
num_hidden_layers=26,
num_attention_heads=8,
num_key_value_heads=4,
head_dim=256,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=131_072,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=1_000_000.0,
attention_bias=False,
attention_dropout=0.0,
query_pre_attn_scalar=256,
sliding_window=4096,
final_logit_softcapping=None,
attn_logit_softcapping=None,
cache_implementation="hybrid",
rope_scaling=None,
rope_local_base_freq=10_000.0,
sliding_window_pattern=6,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.rope_scaling = rope_scaling
rope_config_validation(self)
class Gemma3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
mm_tokens_per_image (`int`, *optional*, defaults to 256):
The number of tokens per image embedding.
boi_token_index (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 256000):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 262144):
The image token index to encode the image prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration
>>> configuration = Gemma3Config(vision_config, text_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3TextConfig(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3"
sub_configs = {
"text_config": Gemma3TextConfig,
"vision_config": SiglipVisionConfig,
}
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
image_token_index: int = 262_144,
initializer_range: float = 0.02,
**kwargs,
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
self.text_config = text_config
self.vision_config = vision_config
self.mm_tokens_per_image = mm_tokens_per_image
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
self.initializer_range = initializer_range
super().__init__(**kwargs)
__all__ = ["Gemma3Config", "Gemma3TextConfig"]

View File

@ -0,0 +1,592 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
--variant='gemma3_4b' \
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
--checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
--output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \
--precision='bfloat16'
"""
import dataclasses
from collections.abc import Iterator, Sequence
from typing import Any
import accelerate
import numpy as np
import torch
import tree
from absl import app, flags, logging
from orbax import checkpoint as obc
from ...image_utils import PILImageResampling
from ..gemma import GemmaTokenizerFast
from . import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3ImageProcessor,
Gemma3Processor,
)
from .configuration_gemma3 import (
Gemma3Config,
Gemma3TextConfig,
SiglipVisionConfig,
)
# ==== Internal Constants and Classes ====
_CHAT_TEMPLATE = """{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
{%- else -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
{%- endif -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>\n' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model\n'}}
{%- endif -%}
"""
_DTYPES = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder"
_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding"
_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_"
_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
_TRANSFORMER_EMBEDDER = "transformer/embedder"
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
_VISION_CONFIG = {
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"num_channels": 3,
"image_size": 896,
"patch_size": 14,
"hidden_act": "gelu_pytorch_tanh",
"layer_norm_eps": 1e-6,
"attention_dropout": 0.0,
"vision_use_head": False,
}
_VARIANT_GEMMA_3_1B = "gemma3_1b"
_VARIANT_GEMMA_3_4B = "gemma3_4b"
_VARIANT_GEMMA_3_12B = "gemma3_12b"
_VARIANT_GEMMA_3_27B = "gemma3_27b"
_VARIANTS = {
_VARIANT_GEMMA_3_1B: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_144,
hidden_size=1152,
intermediate_size=6 * 1152,
num_attention_heads=4,
num_hidden_layers=26,
num_key_value_heads=1,
head_dim=256,
sliding_window=512,
rope_theta=1_000_000, # used for global RoPE only
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
query_pre_attn_scalar=256,
max_position_embeddings=32_768,
),
vision_config=None,
),
_VARIANT_GEMMA_3_4B: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_208,
hidden_size=2560,
intermediate_size=2560 * 8 // 2,
num_attention_heads=8,
head_dim=256,
num_hidden_layers=34,
num_key_value_heads=4,
sliding_window=1024,
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
query_pre_attn_scalar=256,
),
vision_config=_VISION_CONFIG,
),
_VARIANT_GEMMA_3_12B: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_208,
hidden_size=30 * 128,
intermediate_size=30 * 128 * 8 // 2,
num_attention_heads=16,
head_dim=256,
num_hidden_layers=48,
num_key_value_heads=8,
sliding_window=1024,
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
query_pre_attn_scalar=256,
),
vision_config=_VISION_CONFIG,
),
_VARIANT_GEMMA_3_27B: Gemma3Config(
text_config=Gemma3TextConfig(
vocab_size=262_208,
hidden_size=42 * 128,
intermediate_size=42 * 128 * 8 // 2,
num_attention_heads=32,
num_hidden_layers=62,
num_key_value_heads=16,
head_dim=128,
sliding_window=1024,
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
query_pre_attn_scalar=(42 * 128 // 32), # 1 / sqrt(hidden_size // num_attention_heads)
),
vision_config=_VISION_CONFIG,
),
}
# ==== Flags ====
CHECKPOINT_PATH = flags.DEFINE_string(
name="checkpoint_path",
default=None,
help="Path to the Orbax checkpoint.",
required=True,
)
INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
)
OUTPUT_PATH = flags.DEFINE_string(
name="output_path",
default=None,
help="Path to store the HF checkpoint.",
required=True,
)
PRECISION = flags.DEFINE_enum(
name="precision",
default=None,
help="The floating point precision (aka dtype) of the model.",
enum_values=set(_DTYPES.keys()),
required=True,
)
_TEXT_ONLY = flags.DEFINE_bool(
name="text_only",
default=False,
help=(
"If True, the model is loaded and saved as a Gemma3ForCausalLM, "
"otherwise model saed as Gemma3ForConditionalGeneration."
),
)
TOKENIZER_PATH = flags.DEFINE_string(
name="tokenizer_path",
default=None,
help="Path to the SentencePiece model file.",
required=True,
)
_VARIANT = flags.DEFINE_enum(
name="variant",
default=_VARIANT_GEMMA_3_4B,
help="The model variant to convert.",
enum_values=set(_VARIANTS.keys()),
)
def convert_siglip_weight(
config: SiglipVisionConfig,
paths: Sequence[str],
weights: np.ndarray,
) -> tuple[str, np.ndarray]:
path, prop = paths
normalized_path: str = ""
updated_weights: np.ndarray = None
if path == _SIGLIP_BASE:
normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight"
updated_weights = weights.reshape(-1, config.hidden_size)
elif path == _SIGLIP_EMBEDDING:
if prop == "kernel":
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight"
updated_weights = weights.transpose(3, 2, 0, 1)
elif prop == "bias":
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK):
encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:]
next_path_seperator_idx = encoder_block_path.find("/")
layer_idx = encoder_block_path[:next_path_seperator_idx]
encoder_block_path = encoder_block_path[next_path_seperator_idx:]
normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
if encoder_block_path.startswith("/LayerNorm"):
normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2"
if prop == "scale":
normalized_path += ".weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path += ".bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
elif encoder_block_path.startswith("/MlpBlock_0"):
normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
if prop == "kernel":
normalized_path += ".weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path += ".bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
if encoder_block_path.endswith("/key"):
normalized_path += ".self_attn.k_proj"
elif encoder_block_path.endswith("/out"):
normalized_path += ".self_attn.out_proj"
elif encoder_block_path.endswith("/query"):
normalized_path += ".self_attn.q_proj"
elif encoder_block_path.endswith("/value"):
normalized_path += ".self_attn.v_proj"
else:
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
if prop == "bias":
normalized_path += ".bias"
updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
elif prop == "kernel":
normalized_path += ".weight"
updated_weights = weights.reshape(-1, config.hidden_size).transpose()
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
else:
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
if prop == "scale":
normalized_path = "vision_tower.vision_model.post_layernorm.weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path = "vision_tower.vision_model.post_layernorm.bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
else:
raise ValueError(f"Unexpected path `{path}`.")
if "vision" in normalized_path:
print(normalized_path)
return normalized_path, updated_weights
def convert_transformer_weights(
config: Gemma3TextConfig,
paths: Sequence[str],
weights: np.ndarray,
) -> Iterator[tuple[str, np.ndarray]]:
path, prop = paths
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
converted_paths: list[str] = []
converted_weights: list[Any] = []
attn_head_dim = config.num_attention_heads * config.head_dim
kv_head_dim = config.num_key_value_heads * config.head_dim
if path == _TRANSFORMER_EMBEDDER:
if prop == "input_embedding":
# Tied to language_model.lm_head.weight, assigned at the end.
converted_paths = ["language_model.model.embed_tokens.weight"]
if not _TEXT_ONLY.value:
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
pre_expansion_embeddings = weights
mu = np.mean(pre_expansion_embeddings, axis=0)
sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
converted_weights = [weights]
elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
return zip([], [])
else:
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
if _TEXT_ONLY.value:
return zip([], [])
if path.endswith("/mm_input_projection"):
converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
converted_weights = [weights]
elif path.endswith("/mm_soft_embedding_norm"):
converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
converted_weights = [weights]
else:
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
elif path == _TRANSFORMER_FINAL_NORM:
converted_paths = ["language_model.model.norm.weight"]
converted_weights = [weights]
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
next_path_seperator_idx = decoder_block_path.find("/")
layer_idx = decoder_block_path[:next_path_seperator_idx]
decoder_block_path = decoder_block_path[next_path_seperator_idx:]
base_path = f"language_model.model.layers.{layer_idx}"
if path.endswith("attn/attn_vec_einsum"):
converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)]
elif path.endswith("attn/_key_norm"):
converted_paths = [f"{base_path}.self_attn.k_norm.weight"]
converted_weights = [weights]
elif path.endswith("attn/kv_einsum"):
converted_paths = [
f"{base_path}.self_attn.k_proj.weight",
f"{base_path}.self_attn.v_proj.weight",
]
k_proj_weights, v_proj_weights = weights
converted_weights = [
k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
]
elif path.endswith("attn/q_einsum"):
converted_paths = [f"{base_path}.self_attn.q_proj.weight"]
converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)]
elif path.endswith("attn/_query_norm"):
converted_paths = [f"{base_path}.self_attn.q_norm.weight"]
converted_weights = [weights]
elif path.endswith("mlp/gating_einsum"):
converted_paths = [
f"{base_path}.mlp.gate_proj.weight",
f"{base_path}.mlp.up_proj.weight",
]
gate_proj_weight, up_proj_weight = weights
converted_weights = [gate_proj_weight, up_proj_weight]
elif path.endswith("mlp/linear"):
converted_paths = [f"{base_path}.mlp.down_proj.weight"]
converted_weights = [weights.transpose()]
elif path.endswith("post_attention_norm"):
converted_paths = [f"{base_path}.post_attention_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("post_ffw_norm"):
converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("pre_attention_norm"):
converted_paths = [f"{base_path}.input_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("pre_ffw_norm"):
converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"]
converted_weights = [weights]
else:
raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
else:
raise ValueError(f"Unexpected path `{path}`.")
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
raise ValueError(
"The `converted_paths` and `converted_weights` should be the same "
f"length. Got {cpl} and {cwl}, respectively, for {path}."
)
return zip(converted_paths, converted_weights)
@dataclasses.dataclass(frozen=True)
class ConversionResult:
state_tree: dict[str, torch.Tensor]
config: Gemma3Config
def convert(
checkpoint_path: str,
config: Gemma3Config,
target_dtype: torch.dtype,
) -> ConversionResult:
"""Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
checkpointer = obc.PyTreeCheckpointer()
ckpt = checkpointer.restore(checkpoint_path)
hf_tree: dict[str, torch.Tensor] = {}
def update_tree(path: str, weights: np.ndarray) -> None:
torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype)
logging.info(
"%s converted shape=%s with dtype=%s",
path,
weights.shape,
torch_tensor.dtype,
)
hf_tree[path] = torch_tensor
for paths, value in tree.flatten_with_path(ckpt):
if paths[0].startswith("SigLiPFromPatches_"):
if config.vision_config is None:
continue
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
update_tree(path, weights)
else:
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
if config.vision_config is None:
path = path[len("language_model.") :]
update_tree(path, weights)
if config.vision_config is None:
hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
else:
hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
return ConversionResult(state_tree=hf_tree, config=config)
def main(*args):
del args
variant = _VARIANT.value
dtype = getattr(torch, PRECISION.value)
config = _VARIANTS[variant]
output_path = OUTPUT_PATH.value
if variant == _VARIANT_GEMMA_3_1B:
flags.FLAGS.set_default(_TEXT_ONLY.name, True)
tokenizer = GemmaTokenizerFast(
TOKENIZER_PATH.value,
add_bos_token=True,
extra_special_tokens={
"image_token": "<image_soft_token>", # Should be ID=262_144
"boi_token": "<start_of_image>", # Should be ID=255_999
"eoi_token": "<end_of_image>", # Should be ID=256_000
},
)
if INCLUDE_CHAT_TEMPLATE.value:
# Include chat template for CausalLM models
tokenizer.chat_template = _CHAT_TEMPLATE
config.eos_token_id = [1, 106]
if _TEXT_ONLY.value:
config.vision_config = None
tokenizer.save_pretrained(output_path)
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
del tokenizer
else:
image_processor = Gemma3ImageProcessor(
image_seq_length=256,
image_mean=(0.5,) * 3,
image_std=(0.5,) * 3,
size={"height": 896, "width": 896},
resample=PILImageResampling.BILINEAR,
)
processor = Gemma3Processor(
image_processor=image_processor,
tokenizer=tokenizer,
)
if INCLUDE_CHAT_TEMPLATE.value:
# Duplicate so multimodal instruct models can also be used for CausalLM
processor.chat_template = tokenizer.chat_template
processor.save_pretrained(output_path)
logging.info("Saved Gemma3Processor for %s to %s", variant, output_path)
del processor
del tokenizer
logging.info("Gemma 3 (%s) configured as: %s", variant, config)
logging.info("Converting Gemma 3 (%s) @ %s", variant, dtype)
result = convert(CHECKPOINT_PATH.value, config, dtype)
logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
with accelerate.init_empty_weights():
if config.vision_config is None:
model = Gemma3ForCausalLM(config=config.text_config)
else:
model = Gemma3ForConditionalGeneration(config)
model.load_state_dict(result.state_tree, assign=True, strict=True)
model.config.torch_dtype = dtype
logging.info("Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.", variant, type(model).__name__)
model.save_pretrained(output_path, safe_serialization=True)
logging.info(
"Saved Gemma 3 (%s) to SafeTensors in %s using %s",
variant,
output_path,
type(model).__name__,
)
del model
del result
if __name__ == "__main__":
app.run(main)

View File

@ -0,0 +1,413 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Gemma3."""
import itertools
import math
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_nested_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class Gemma3ImageProcessor(BaseImageProcessor):
r"""
Constructs a SigLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*):
Whether to apply `pan_and_scan` to images.
pan_and_scan_min_crop_size (`int`, *optional*):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
"""
model_input_names = ["pixel_values", "num_crops"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size, default_to_square=True)
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pan_and_scan = do_pan_and_scan
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
def pan_and_scan(
self,
image: np.ndarray,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds
minumum allowed ratio.
Args:
image (`np.ndarray`):
Image to resize.
pan_and_scan_min_crop_size (`int`, *optional*):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
height, width = get_image_size(image)
# Square or landscape image.
if width >= height:
# Only apply PaS if the image is sufficiently exaggerated
if width / height < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
# Portrait image.
else:
# Only apply PaS if the image is sufficiently exaggerated
if height / width < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_h = int(math.floor(height / width + 0.5))
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(width / num_crops_w))
crop_size_h = int(math.ceil(height / num_crops_h))
# Don't apply PaS if crop size is too small.
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return []
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
if input_data_format == ChannelDimension.LAST:
image_crops = [
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
]
else:
image_crops = [
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
]
return image_crops
def _process_images_for_pan_and_scan(
self,
images: List[np.ndarray],
do_pan_and_scan: bool,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
pas_images_list = []
num_crops = []
for image in images:
pas_images = self.pan_and_scan(
image=image,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
pas_images_list.extend([image] + pas_images)
num_crops.append(len(pas_images))
return pas_images_list, num_crops
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: bool = None,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to apply `pan_and_scan` to images.
pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`):
Minimum aspect ratio to activate pan and scan.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_pan_and_scan = do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
pan_and_scan_min_crop_size = (
pan_and_scan_min_crop_size if pan_and_scan_min_crop_size is not None else self.pan_and_scan_min_crop_size
)
pan_and_scan_max_num_crops = (
pan_and_scan_max_num_crops if pan_and_scan_max_num_crops is not None else self.pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate = (
pan_and_scan_min_ratio_to_activate
if pan_and_scan_min_ratio_to_activate is not None
else self.pan_and_scan_min_ratio_to_activate
)
images_list = make_nested_list_of_images(images)
if not valid_images(images_list[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
if do_rescale and is_scaled_image(images_list[0][0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
if do_pan_and_scan:
images_list_and_num_crops = [
self._process_images_for_pan_and_scan(
images=images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
for images in images_list
]
images_list = [images for images, _ in images_list_and_num_crops]
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
else:
num_crops = [[0] for images in images_list]
processed_images = []
for images in images_list:
for image in images:
if do_resize:
height, width = size["height"], size["width"]
image = resize(
image=image, size=(height, width), resample=resample, input_data_format=input_data_format
)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
processed_images.append(image)
data = {"pixel_values": processed_images, "num_crops": num_crops}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma3ImageProcessor"]

View File

@ -0,0 +1,387 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SigLIP."""
import itertools
import math
from functools import partial
from typing import List, Optional, Union
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
get_size_dict,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
SizeDict,
get_image_size,
make_nested_list_of_images,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)
if is_vision_available():
from ...image_utils import PILImageResampling
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
class Gemma3FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
class Gemma3FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of Pan adn Scan cropping method.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
do_pan_and_scan (`bool`, *optional*):
Whether to apply `pan_and_scan` to images.
pan_and_scan_min_crop_size (`int`, *optional*):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
""",
)
class Gemma3ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 224, "width": 224}
default_to_square = True
do_resize = True
do_rescale = True
do_normalize = True
do_pan_and_scan = None
pan_and_scan_min_crop_size = None
pan_and_scan_max_num_crops = None
pan_and_scan_min_ratio_to_activate = None
valid_init_kwargs = Gemma3FastImageProcessorInitKwargs
valid_preprocess_kwargs = Gemma3FastImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorInitKwargs]):
super().__init__(**kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_nested_list_of_images(images)
def _prepare_input_images(
self,
images: ImageInput,
do_convert_rgb: bool = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> List["torch.Tensor"]:
"""
Prepare the input images for processing.
"""
batch_images = self._prepare_images_structure(images)
process_image_fn = partial(
self._process_image,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# todo: yoni - check if we can parallelize this efficiently
batch_processed_images = []
for image_list in batch_images:
processed_images = []
for image in image_list:
processed_images.append(process_image_fn(image))
batch_processed_images.append(processed_images)
return batch_processed_images
def pan_and_scan(
self,
image: "torch.Tensor",
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
):
"""
Pan and Scan an image, by cropping into smaller images when the aspect ratio exceeds
minumum allowed ratio.
Args:
image (`torch.Tensor`):
Image to resize.
pan_and_scan_min_crop_size (`int`, *optional*):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
"""
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
# Square or landscape image.
if width >= height:
# Only apply PaS if the image is sufficiently exaggerated
if width / height < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
# Portrait image.
else:
# Only apply PaS if the image is sufficiently exaggerated
if height / width < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_h = int(math.floor(height / width + 0.5))
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(width / num_crops_w))
crop_size_h = int(math.ceil(height / num_crops_h))
# Don't apply PaS if crop size is too small.
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return []
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
return [
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
]
def _process_images_for_pan_and_scan(
self,
images: List["torch.Tensor"],
do_pan_and_scan: bool,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
):
pas_images_list = []
num_crops = []
for image in images:
pas_images = self.pan_and_scan(
image=image,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
)
pas_images_list.extend([image] + pas_images)
num_crops.append(len(pas_images))
return pas_images_list, num_crops
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
do_pan_and_scan (`bool`, *optional*):
Whether to apply `pan_and_scan` to images.
pan_and_scan_min_crop_size (`int`, *optional*):
Minimum size of each crop in pan and scan.
pan_and_scan_max_num_crops (`int`, *optional*):
Maximum number of crops per image in pan and scan.
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
Minimum aspect ratio to activate pan and scan.
""",
)
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[Gemma3FastImageProcessorPreprocessKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys()
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_preprocess_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0][0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
def _preprocess(
self,
images: List[List["torch.Tensor"]],
do_resize: bool,
size: SizeDict,
do_pan_and_scan: Optional[bool],
pan_and_scan_min_crop_size: Optional[int],
pan_and_scan_max_num_crops: Optional[int],
pan_and_scan_min_ratio_to_activate: Optional[float],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
batch_num_crops = []
for image_list in images:
if do_pan_and_scan:
images_list, num_crops = self._process_images_for_pan_and_scan(
images=image_list,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
)
else:
num_crops = [[0] for images in images_list]
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.extend(processed_image_patches)
batch_num_crops.extend(num_crops)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors
)
__all__ = ["Gemma3ImageProcessorFast"]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,848 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
ModelOutput,
)
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
logging,
)
from ..bart.modeling_bart import BartScaledWordEmbedding
from ..gemma2.configuration_gemma2 import Gemma2Config
from ..gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2ForCausalLM,
Gemma2MLP,
Gemma2Model,
Gemma2PreTrainedModel,
Gemma2RMSNorm,
Gemma2RotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from ..siglip import SiglipVisionConfig
_CHECKPOINT_FOR_DOC = "google/gemma-3-4b"
_CONFIG_FOR_DOC = "Gemma3Config"
logger = logging.get_logger(__name__)
GEMMA3_INPUTS_DOCSTRING = ""
class Gemma3TextConfig(Gemma2Config):
r"""
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma3Text-7B.
e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262208):
Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma3TextModel`]
hidden_size (`int`, *optional*, defaults to 2304):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 9216):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 26):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
Scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
size of the sliding window.
final_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*):
Scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
```python
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
>>> configuration = Gemma3TextConfig()
>>> # Initializing a model from the gemma3_text-7b style configuration
>>> model = Gemma3TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
"""
model_type = "gemma3_text"
def __init__(
self,
vocab_size=262_208,
rope_theta=1_000_000.0,
rope_scaling=None,
rope_local_base_freq=10_000.0,
sliding_window_pattern=6,
max_position_embeddings=131_072,
final_logit_softcapping=None,
attn_logit_softcapping=None,
**super_kwargs,
):
super().__init__(self, **super_kwargs)
self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.rope_scaling = rope_scaling
rope_config_validation(self)
class Gemma3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
mm_tokens_per_image (`int`, *optional*, defaults to 256):
The number of tokens per image embedding.
boi_token_index (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 256000):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 262144):
The image token index to encode the image prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration
>>> configuration = Gemma3Config(vision_config, text_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3TextConfig(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3"
sub_configs = {
"text_config": Gemma3TextConfig,
"vision_config": SiglipVisionConfig,
}
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
image_token_index: int = 262_144,
initializer_range: float = 0.02,
**kwargs,
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
self.text_config = text_config
self.vision_config = vision_config
self.mm_tokens_per_image = mm_tokens_per_image
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
self.initializer_range = initializer_range
super().__init__(**kwargs)
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
Base class for Gemma3 causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
class Gemma3TextScaledWordEmbedding(BartScaledWordEmbedding):
pass
class Gemma3MLP(Gemma2MLP):
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)
class Gemma3RMSNorm(Gemma2RMSNorm):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
def __init__(self, config: Gemma3TextConfig, device=None):
super().__init__(config)
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
class Gemma3Attention(Gemma2Attention):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
super().__init__()
self.sliding_window = config.sliding_window if self.is_sliding else None
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
"Falling back to eager attention. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask.to(query_states),
dropout=self.attention_dropout if self.training else 0.0,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class Gemma3DecoderLayer(nn.Module):
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma3MLP(config)
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings_global: torch.Tensor,
position_embeddings_local: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# apply global RoPE to non-sliding layer only
if self.self_attn.is_sliding:
position_embeddings = position_embeddings_local
else:
position_embeddings = position_embeddings_global
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
GEMMA3_START_DOCSTRING = None
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
base_model_prefix = "language_model"
_no_split_modules = [
"Gemma3DecoderLayer",
"SiglipVisionEmbeddings",
"SiglipEncoderLayer",
"SiglipMultiheadAttentionPoolingHead",
]
def _init_weights(self, module):
# important: this ported version of Gemma2 isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Gemma3TextModel(Gemma2Model):
config_class = Gemma3TextConfig
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)
# Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
self.embed_tokens = Gemma3TextScaledWordEmbedding(
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
)
# TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas
# when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
config = copy.deepcopy(config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
# embed positions
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings_global,
position_embeddings_local,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
class Gemma3ForCausalLM(Gemma2ForCausalLM):
config_class = Gemma3TextConfig
base_model_prefix = "language_model"
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)
self.model = Gemma3TextModel(config)
class Gemma3MultiModalProjector(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
def tie_weights(self):
return self.language_model.tie_weights()
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def _update_causal_mask(
self,
attention_mask,
token_type_ids,
past_key_values,
cache_position,
input_tensor,
is_training: bool = False,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
return attention_mask
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted
# form and requires no inversion or slicing.
return attention_mask
using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else cache_position[0] + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
# Apply bidirectional mask on images if token type ids are provided
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
# Then apply padding mask (will mask pad tokens)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Gemma3Config",
"Gemma3TextConfig",
"Gemma3PreTrainedModel", # noqa: F822
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
]

View File

@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import to_py_obj
class Gemma3ImagesKwargs(ImagesKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
do_convert_rgb: Optional[bool]
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Gemma3ImagesKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
"do_pan_and_scan": False,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
},
}
class Gemma3Processor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "image_seq_length"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor,
tokenizer,
chat_template=None,
image_seq_length: int = 256,
**kwargs,
):
self.image_seq_length = image_seq_length
self.image_token_id = tokenizer.image_token_id
self.boi_token = tokenizer.boi_token
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
super().__init__(
image_processor=image_processor,
tokenizer=tokenizer,
chat_template=chat_template,
**kwargs,
)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
videos=None,
audio=None,
**kwargs: Unpack[Gemma3ProcessorKwargs],
) -> BatchFeature:
if text is None and images is None:
raise ValueError("Provide at least one of `text` or `images`.")
output_kwargs = self._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
image_inputs = {}
if images is not None:
batched_images = make_nested_list_of_images(images)
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
# Create empty text to be replaced with placeholders
if not text:
text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
if len(batched_images) != len(text):
raise ValueError(
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
)
# Replace image tokens by the full expanded sequence
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
text_with_crops = text
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
if len(images) != len(image_indexes):
raise ValueError(
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
)
# Insert additional image tokens for Pan-and-Scan crops
for num, idx in reversed(list(zip(num_crops, image_indexes))):
if num:
formatted_image_text = (
f"Here is the original image {self.boi_token} and here are some crops to help you see better "
+ " ".join([self.boi_token] * num)
)
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
text_with_crops[batch_idx] = prompt
# Expand placeholder image tokens to the full image token sequence
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
# Add token type ids manually, as tokenizer can't do arbitrary position token types
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["Gemma3Processor"]

View File

@ -477,11 +477,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -490,8 +485,16 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
is_training = token_type_ids is not None and labels is not None
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -506,10 +509,16 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "

View File

@ -4609,6 +4609,34 @@ class Gemma2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Gemma3ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma3ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Gemma3TextModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GitForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class Gemma3ImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class GotOcr2ImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

View File

@ -289,6 +289,13 @@ class FuyuProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Gemma3ImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class GLPNFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]

View File

@ -124,6 +124,7 @@ VLM_CLASS_NAMES = [
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"gemma3",
]

View File

@ -353,7 +353,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_Gemma_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)

View File

@ -153,6 +153,13 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_sdpa_equivalence(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@slow
@require_torch_gpu

View File

View File

@ -0,0 +1,229 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import Gemma3ImageProcessor
if is_torchvision_available():
from transformers import Gemma3ImageProcessorFast
class Gemma3ImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
do_normalize=True,
image_mean=IMAGENET_STANDARD_MEAN,
image_std=IMAGENET_STANDARD_STD,
do_convert_rgb=True,
do_pan_and_scan=True,
pan_and_scan_min_crop_size=10,
pan_and_scan_max_num_crops=2,
pan_and_scan_min_ratio_to_activate=1.2,
):
super().__init__()
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pan_and_scan = do_pan_and_scan
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"do_pan_and_scan": self.do_pan_and_scan,
"pan_and_scan_min_crop_size": self.pan_and_scan_min_crop_size,
"pan_and_scan_max_num_crops": self.pan_and_scan_max_num_crops,
"pan_and_scan_min_ratio_to_activate": self.pan_and_scan_min_ratio_to_activate,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Gemma3ImageProcessor if is_vision_available() else None
fast_image_processing_class = Gemma3ImageProcessorFast if is_torchvision_available() else None
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Gemma3
def setUp(self):
super().setUp()
self.image_processor_tester = Gemma3ImageProcessingTester(self)
@property
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_pan_and_scan"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_crop_size"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_max_num_crops"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_ratio_to_activate"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=84)
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
def test_pan_and_scan(self):
"""
Enables Pan and Scan path by choosing the correct input image resolution. If you are changing
image processor attributes for PaS, please update this test.
"""
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
"""This function prepares a list of PIL images"""
image_inputs = [np.random.randint(255, size=(3, 300, 600), dtype=np.uint8)] * 3
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
# Test not batched input, 3 images because we have base image + 2 crops
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (3, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched, 9 images because we have base image + 2 crops per each item
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
def test_call_numpy_4_channels(self):
pass

View File

@ -0,0 +1,520 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""
import unittest
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...models.gemma.test_modeling_gemma import GemmaModelTester
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3Processor,
Gemma3TextModel,
)
class Gemma3ModelTester(GemmaModelTester):
if is_torch_available():
config_class = Gemma3TextConfig
model_class = Gemma3TextModel
for_causal_lm_class = Gemma3ForCausalLM
@require_torch
class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3TextModel, Gemma3ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Gemma3ForCausalLM,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
def setUp(self):
self.model_tester = Gemma3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
class Gemma3Vision2TextModelTester:
def __init__(
self,
parent,
mm_tokens_per_image=2,
image_token_index=1,
boi_token_index=2,
eoi_token_index=3,
seq_length=25,
is_training=True,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
use_cache=False,
):
self.parent = parent
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.mm_tokens_per_image = mm_tokens_per_image
self.image_token_index = image_token_index
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.llm_tester = Gemma3ModelTester(self.parent)
self.text_config = self.llm_tester.get_config()
self.vision_config = vision_config
self.seq_length = seq_length
self.pad_token_id = self.text_config.pad_token_id
self.num_hidden_layers = self.text_config.num_hidden_layers
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
self.num_attention_heads = self.text_config.num_attention_heads
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.encoder_seq_length = seq_length
self.use_cache = use_cache
def get_config(self):
return Gemma3Config(
text_config=self.text_config,
vision_config=self.vision_config,
image_token_index=self.image_token_index,
boi_token_index=self.boi_token_index,
eoi_token_index=self.eoi_token_index,
mm_tokens_per_image=self.mm_tokens_per_image,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
# set the 3 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, :1] = config.image_token_index
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == config.image_token_index] = 1
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
return config, inputs_dict
@require_torch
class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
test_missing_keys = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
def setUp(self):
self.model_tester = Gemma3Vision2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)
def test_initialization(self):
pass
@unittest.skip(
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
)
def test_flex_attention_with_grads(self):
pass
@slow
@require_torch_gpu
# @require_read_token
class Gemma3IntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
self.messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_model_4b_bf16(self):
model_id = "gg-hf-g/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_batch(self):
model_id = "gg-hf-g/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = [
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_multiimage(self):
model_id = "gg-hf-g/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What do you see here?"},
],
},
]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_1b_text_only(self):
model_id = "gg-hf-g/gemma-3-1b-it"
model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
# TODO: raushan FA2 generates gibberish for no reason, check later
# @require_flash_attn
# @require_torch_gpu
# @mark.flash_attn_test
# def test_model_4b_flash_attn(self):
# model_id = "gg-hf-g/gemma-3-4b-it"
#
# model = Gemma3ForConditionalGeneration.from_pretrained(
# model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
# ).to(torch_device)
#
# inputs = self.processor.apply_chat_template(
# self.messages,
# tokenize=True,
# return_dict=True,
# return_tensors="pt",
# add_generation_prompt=True,
# ).to(torch_device)
#
# output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# output_text = self.processor.batch_decode(output, skip_special_tokens=True)
#
# EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**'] # fmt: skip
# self.assertEqual(output_text, EXPECTED_TEXTS)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "gg-hf-g/gemma-3-1b-it"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@ -0,0 +1,136 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from typing import Optional
from transformers import Gemma3Processor, GemmaTokenizer
from transformers.testing_utils import get_tests_dir, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Gemma3ImageProcessor
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_vision
class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Gemma3Processor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
gemma3_image_processor_kwargs = {
"do_pan_and_scan": True,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
}
image_processor = Gemma3ImageProcessor.from_pretrained(
"google/siglip-so400m-patch14-384", **gemma3_image_processor_kwargs
)
extra_special_tokens = {
"image_token": "<image_soft_token>",
"boi_token": "<start_of_image>",
"eoi_token": "<end_of_image>",
}
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens)
processor_kwargs = self.prepare_processor_dict()
processor = Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# TODO: raushan or arthur: add the real chat template
def prepare_processor_dict(self):
return {
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n", "image_seq_length": 3,
} # fmt: skip
# Override as VLMs need image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <start_of_image>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <start_of_image>"]
return ["lower newer <start_of_image>", "<start_of_image> upper older longer string"] + [
"<start_of_image> lower newer"
] * (batch_size - 2)
# Override as Gemma3 needs images to be an explicitly nested batch
def prepare_image_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of PIL images for testing"""
images = super().prepare_image_inputs(batch_size)
if isinstance(images, (list, tuple)):
images = [[image] for image in images]
return images
def test_text_with_image_tokens(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!"
text_single_image = f"{processor.boi_token}Dummy text!"
text_no_image = "Dummy text!"
image = self.prepare_image_inputs()
# If text has no image tokens, iamge should be `None`
with self.assertRaises(ValueError):
_ = processor(text=text_no_image, images=image, return_tensors="np")
# We can't be sure what is users intention: if user wants one image per text OR two images for first text and no image for second text
with self.assertRaises(ValueError):
_ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")
# The users is expected to be explicit about which image belong to which text by nesting the images list
out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
out_batch_oneimage = processor(
text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
)
self.assertListEqual(
out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
)
def test_pan_and_scan(self):
processor_components = self.prepare_components()
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(**processor_components, **processor_kwargs)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="np",
do_pan_and_scan=True,
image_seq_length=2,
pan_and_scan_min_crop_size=10,
)
# base image + 4 crops
self.assertEqual(len(inputs[self.images_input_name]), 5)
self.assertEqual(len(inputs[self.text_input_name][0]), 67)

View File

@ -783,7 +783,7 @@ class ProcessorTesterMixin:
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Now test the ability to return dict
messages[0][0]["content"].append(
@ -845,7 +845,7 @@ class ProcessorTesterMixin:
return_dict=True,
padding=True,
)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Now test the ability to return dict
batched_messages[0][0]["content"].append(
@ -885,6 +885,7 @@ class ProcessorTesterMixin:
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
@ -982,7 +983,7 @@ class ProcessorTesterMixin:
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Add video URL for return dict and load with `num_frames` arg
messages[0][0]["content"][0] = {

View File

@ -226,6 +226,8 @@ SPECIAL_CASES_TO_ALLOW = {
"giou_loss_coefficient",
],
"GPTNeoXConfig": ["rotary_emb_base"],
"Gemma3Config": ["boi_token_index", "eoi_token_index"],
"Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
}