mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Gemma3 (#36658)
* Fix converter
* [Broken] Adds Gemma 3 to Hugging Face Transformers
* Consolidating Config and Processor params across impls
* Sorting out configuration parameters. Adds qk_norm before RoPE. Still not sure if RoPE is right.
* Additional plumbing for CausalLM and ConditionalGeneration variants
* incomplete draft of Orbax conversion script
* More complete checkpoint conversion
* Supporting Gemma 3 1B checkpoints
* Updating RoPE for multiple frequencies
* Adjustments to rotary embedder
* Proof of life for text-only operation
* Updating the conversion script to handle multimodal projection weights
* Fixing tet-only conversions
* Cleaner conversion script with multimodal support and a simpler processor
* Additional refatcors to the Gemma3Processor
* Simplified Processor to work over text representations
* Updated conversion script to join text and vision embeddings at converion time
* Logging for debugging
* Update src/transformers/models/gemma2/modeling_gemma2.py
Co-authored-by: Joshua Lochner <admin@xenova.com>
* Removed extraneous Config params
* Switching to fast tokenizer for checkpoint conversions
* isolating siglip for performance tetsing
* Minor changes for debugging tests against baselines
* Adding average pooling for soft tokens
* Updating processor code to enable simpler embedding interleaving for arbitrary number of images in prompts
* Updating conversion script for ShieldGemma 2 conversion compatibility
* Allow disable_compile to be provided as a kwarg
* Refresh from modular
* Updated conversion script and corrected sliding window
* Fix type mismatch in cache_position (#4)
* Fix dtype (#5)
* Fix type mismatch in cache_position
* Actually fix in the modular file
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
---------
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
* fixes for embedding table overflow and missing image_soft_token_mask from Gemma3Processor
* Adding 2D pooling for image embeddings
* Revert "Adding 2D pooling for image embeddings"
This reverts commit 65350cf531
.
* Gemma3 average pooling changed from 1D to 2D
* Major refactor to Gemma3MultimodalInputProjection
* Updating Gemm 3 Auto* registrations
* Add option to save Gemma 3 chat template with tokenizer during weights conversion
* Removing unused imports
* Moving out-of-vocab handling from Gemma3Processor to Gemma3ForConditionalGeneration
* Removing duplicate config property
* Removing final logit softcapping and 1-indexing of position ids
* Fixing image processor config and none --> None typo
* Fixing sliding window size for 1B
* Updating image_mean and image_std in Image Processor
* Attention masking changed to lower triangular
* Moving image special tokens to conversion script
* Mirror image processor defaults from conversion script into Gemma3ProcessorKwargs
* Remove special token variables from symbol space
* Moving image soft token mask computation from Gemma3Processor to Gemma3ForConditionalGeneration
* tie lm_head and embedding weights
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
* Correct tied weights in Gemma3CausalLM
* iterative bidirectional attention
* resolving merge conflicts
* Reverting to Gemma 2 HybridCache with sldiing window support and a sliding_window_pattern of 6
* Correcting RoPE scaling
* clean up first pass, dummy model geenration works
* final clean up before fixing tests
* causal lm test works, so fine
* Fix conversion
* Update src/transformers/models/gemma3/processing_gemma3.py
* model tests are happy
* processor tests are happy
* image processing tests added
* fixup
* Fix pre-processing in conversion
* Inputs merging
* Do not normalize vision embeddings
* Apply Ryan's (and team) changes to attention
* token type ids + mask
* template
* move embed scale, add rope scale, fix tests
* Add chat template to tokenizer
* Use prefix for causal model loading
* use existing code for sliding mask from gemma2
* self.embed_tokens already normalizes
* Correcting Gemma3TextConfig parameters in conversion script
* typo, modular overwrites my fixes
* enable device map for text model
* Conversion updates
* ultra nit: no einsums
* update image token
* copy deepcopy config + some docs
* add some test, still WIP
* Refactoring --include_chat_tempalte logic in converter
* Update src/transformers/models/gemma3/modular_gemma3.py
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
* Add eos tokens for instruct models
* dump so i can work on dgx
* Removing add_bos by default
* dump
* add fast im proc
* docs for PaS + fixup
* another fixup
* one more fixup
* fix tests
* Inverting prior BOS change
* ultra nit
* Reverting to Tokenizer saved with add_bos_token=True and chat template starting with BOS
* resize embeds, remove sqrt, add slow test outputs
* FA2 but quality is meh
* nit
* skip FA2, no idea what happened
* last bit for green CI
* please, green CI for docs
* T_T
* Fix for Gemma3 logits
* Support both options for system prompt
* Update src/transformers/models/gemma3/image_processing_gemma3_fast.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Docs updates now that assets are live
* Style fixes
---------
Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Lysandre <hi@lysand.re>
This commit is contained in:
parent
81aa9b2e07
commit
50d3530aa0
203
docs/source/en/model_doc/gemma3.md
Normal file
203
docs/source/en/model_doc/gemma3.md
Normal file
@ -0,0 +1,203 @@
|
||||
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Gemma3
|
||||
|
||||
## Overview
|
||||
|
||||
The Gemma 3 model was proposed in the [Gemma 3 Techncial Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed by a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. It cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed certain aspect ratio. For images that exceed the given aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where each sixth layer is a full causal attention layer.
|
||||
|
||||
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq).
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
|
||||
- For image+text and image-only inputs use `Gemma3ForConditionalGeneration`.
|
||||
- For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
|
||||
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
|
||||
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
|
||||
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
|
||||
|
||||
|
||||
### Image cropping for high resolution images
|
||||
|
||||
The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image to improve the quality in DocVQA or similar tasks requiring higher resolution images.
|
||||
|
||||
Pan and scan is an inference time optimization to handle images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc.
|
||||
|
||||
```python
|
||||
|
||||
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
|
||||
|
||||
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
do_pan_and_scan=True,
|
||||
).to(model.device)
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Usage Example
|
||||
|
||||
### Single-image Inference
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
```
|
||||
|
||||
### Multi-image Inference
|
||||
|
||||
```python
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url_cow},
|
||||
{"type": "image", "url": url_stop},
|
||||
{"type": "text", "text": "Are these two images identical?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
|
||||
```
|
||||
|
||||
### Text-only inference
|
||||
|
||||
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
|
||||
```python
|
||||
from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=100)
|
||||
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
print(text)
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Gemma3ImageProcessor
|
||||
|
||||
[[autodoc]] Gemma3ImageProcessor
|
||||
|
||||
## Gemma3ImageProcessorFast
|
||||
|
||||
[[autodoc]] Gemma3ImageProcessorFast
|
||||
|
||||
## Gemma3Processor
|
||||
|
||||
[[autodoc]] Gemma3Processor
|
||||
|
||||
## Gemma3TextConfig
|
||||
|
||||
[[autodoc]] Gemma3TextConfig
|
||||
|
||||
## Gemma3Config
|
||||
|
||||
[[autodoc]] Gemma3Config
|
||||
|
||||
## Gemma3TextModel
|
||||
|
||||
[[autodoc]] Gemma3TextModel
|
||||
- forward
|
||||
|
||||
## Gemma3ForCausalLM
|
||||
|
||||
[[autodoc]] Gemma3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Gemma3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Gemma3ForConditionalGeneration
|
||||
- forward
|
@ -474,6 +474,7 @@ _import_structure = {
|
||||
"models.fuyu": ["FuyuConfig"],
|
||||
"models.gemma": ["GemmaConfig"],
|
||||
"models.gemma2": ["Gemma2Config"],
|
||||
"models.gemma3": ["Gemma3Config", "Gemma3Processor", "Gemma3TextConfig"],
|
||||
"models.git": [
|
||||
"GitConfig",
|
||||
"GitProcessor",
|
||||
@ -1259,6 +1260,7 @@ else:
|
||||
_import_structure["models.emu3"].append("Emu3ImageProcessor")
|
||||
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
|
||||
_import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
|
||||
_import_structure["models.gemma3"].append("Gemma3ImageProcessor")
|
||||
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
|
||||
_import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
|
||||
_import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
|
||||
@ -1332,6 +1334,7 @@ else:
|
||||
_import_structure["models.deit"].append("DeiTImageProcessorFast")
|
||||
_import_structure["models.depth_pro"].append("DepthProImageProcessorFast")
|
||||
_import_structure["models.detr"].append("DetrImageProcessorFast")
|
||||
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
|
||||
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
|
||||
_import_structure["models.llava"].append("LlavaImageProcessorFast")
|
||||
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
|
||||
@ -2452,6 +2455,14 @@ else:
|
||||
"Gemma2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.gemma3"].extend(
|
||||
[
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Gemma3PreTrainedModel",
|
||||
"Gemma3TextModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.git"].extend(
|
||||
[
|
||||
"GitForCausalLM",
|
||||
@ -2554,6 +2565,7 @@ else:
|
||||
"GraniteMoePreTrainedModel",
|
||||
]
|
||||
)
|
||||
|
||||
_import_structure["models.granitemoeshared"].extend(
|
||||
[
|
||||
"GraniteMoeSharedForCausalLM",
|
||||
@ -2561,7 +2573,6 @@ else:
|
||||
"GraniteMoeSharedPreTrainedModel",
|
||||
]
|
||||
)
|
||||
|
||||
_import_structure["models.grounding_dino"].extend(
|
||||
[
|
||||
"GroundingDinoForObjectDetection",
|
||||
@ -5629,6 +5640,7 @@ if TYPE_CHECKING:
|
||||
from .models.fuyu import FuyuConfig
|
||||
from .models.gemma import GemmaConfig
|
||||
from .models.gemma2 import Gemma2Config
|
||||
from .models.gemma3 import Gemma3Config, Gemma3Processor, Gemma3TextConfig
|
||||
from .models.git import (
|
||||
GitConfig,
|
||||
GitProcessor,
|
||||
@ -6450,6 +6462,7 @@ if TYPE_CHECKING:
|
||||
FlavaProcessor,
|
||||
)
|
||||
from .models.fuyu import FuyuImageProcessor, FuyuProcessor
|
||||
from .models.gemma3 import Gemma3ImageProcessor
|
||||
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
|
||||
from .models.got_ocr2 import GotOcr2ImageProcessor
|
||||
from .models.grounding_dino import GroundingDinoImageProcessor
|
||||
@ -6535,6 +6548,7 @@ if TYPE_CHECKING:
|
||||
from .models.deit import DeiTImageProcessorFast
|
||||
from .models.depth_pro import DepthProImageProcessorFast
|
||||
from .models.detr import DetrImageProcessorFast
|
||||
from .models.gemma3 import Gemma3ImageProcessorFast
|
||||
from .models.got_ocr2 import GotOcr2ImageProcessorFast
|
||||
from .models.llava import LlavaImageProcessorFast
|
||||
from .models.llava_next import LlavaNextImageProcessorFast
|
||||
@ -7461,6 +7475,12 @@ if TYPE_CHECKING:
|
||||
Gemma2Model,
|
||||
Gemma2PreTrainedModel,
|
||||
)
|
||||
from .models.gemma3 import (
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3PreTrainedModel,
|
||||
Gemma3TextModel,
|
||||
)
|
||||
from .models.git import (
|
||||
GitForCausalLM,
|
||||
GitModel,
|
||||
|
@ -113,10 +113,10 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):
|
||||
sp = self.sp
|
||||
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
|
||||
|
||||
# there is a missing token in the vocab. We have to do this to support merges
|
||||
# If "\t" is missing in the vocab, we have to do this to support merges
|
||||
# "<0x09>" is the bytefallback for `\t`
|
||||
vocab["\t"] = vocab.get("<0x09>")
|
||||
|
||||
if "\t" not in vocab:
|
||||
vocab["\t"] = vocab.get("<0x09>")
|
||||
merges = generate_merges(vocab, vocab_scores)
|
||||
return vocab, merges
|
||||
|
||||
@ -1296,12 +1296,14 @@ class GemmaConverter(SpmConverter):
|
||||
(self.original_tokenizer.eos_token, 0.0),
|
||||
(self.original_tokenizer.bos_token, 0.0),
|
||||
]
|
||||
for piece in proto.pieces[3:]:
|
||||
if piece.piece == "<0x09>":
|
||||
vocab += [("\t", piece.score)]
|
||||
else:
|
||||
vocab += [(piece.piece, piece.score)]
|
||||
# vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
||||
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
||||
|
||||
# Older gemma tokenizers had a missing tab token, so we fix that here
|
||||
if not any(x[0] == "\t" for x in vocab):
|
||||
override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None)
|
||||
if override_index is not None:
|
||||
vocab[override_index] = ("\t", 0.0)
|
||||
|
||||
return vocab
|
||||
|
||||
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||
|
@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
|
||||
is_quantized = hf_quantizer is not None
|
||||
|
||||
for serialized_param_name, empty_param in state_dict.items():
|
||||
if serialized_param_name not in expected_keys:
|
||||
continue
|
||||
|
||||
# serialized_param_name is the raw, serialized name
|
||||
# fixed_param_name is the model's equivalent
|
||||
fixed_param_name, _ = model.rename_key(serialized_param_name)
|
||||
|
||||
if fixed_param_name not in expected_keys:
|
||||
continue
|
||||
|
||||
# we need to use serialized_param_name as file pointer is untouched
|
||||
if shard_file.endswith(".safetensors"):
|
||||
param = file_pointer.get_slice(serialized_param_name)
|
||||
|
@ -106,6 +106,7 @@ from . import (
|
||||
fuyu,
|
||||
gemma,
|
||||
gemma2,
|
||||
gemma3,
|
||||
git,
|
||||
glm,
|
||||
glpn,
|
||||
|
@ -124,6 +124,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuConfig"),
|
||||
("gemma", "GemmaConfig"),
|
||||
("gemma2", "Gemma2Config"),
|
||||
("gemma3", "Gemma3Config"),
|
||||
("gemma3_text", "Gemma3TextConfig"),
|
||||
("git", "GitConfig"),
|
||||
("glm", "GlmConfig"),
|
||||
("glpn", "GLPNConfig"),
|
||||
@ -459,6 +461,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("fuyu", "Fuyu"),
|
||||
("gemma", "Gemma"),
|
||||
("gemma2", "Gemma2"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("gemma3_text", "Gemma3ForCausalLM"),
|
||||
("git", "GIT"),
|
||||
("glm", "GLM"),
|
||||
("glpn", "GLPN"),
|
||||
@ -748,6 +752,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("qwen2_audio_encoder", "qwen2_audio"),
|
||||
("clip_text_model", "clip"),
|
||||
("aria_text", "aria"),
|
||||
("gemma3_text", "gemma3"),
|
||||
("idefics3_vision", "idefics3"),
|
||||
("siglip_vision_model", "siglip"),
|
||||
("smolvlm_vision", "smolvlm"),
|
||||
|
@ -86,6 +86,7 @@ else:
|
||||
("flava", ("FlavaImageProcessor",)),
|
||||
("focalnet", ("BitImageProcessor",)),
|
||||
("fuyu", ("FuyuImageProcessor",)),
|
||||
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
|
||||
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("glpn", ("GLPNImageProcessor",)),
|
||||
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
|
@ -118,6 +118,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("funnel", ("FunnelModel", "FunnelBaseModel")),
|
||||
("gemma", "GemmaModel"),
|
||||
("gemma2", "Gemma2Model"),
|
||||
("gemma3_text", "Gemma3TextModel"),
|
||||
("git", "GitModel"),
|
||||
("glm", "GlmModel"),
|
||||
("glpn", "GLPNModel"),
|
||||
@ -338,6 +339,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("fnet", "FNetForPreTraining"),
|
||||
("fsmt", "FSMTForConditionalGeneration"),
|
||||
("funnel", "FunnelForPreTraining"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("gpt-sw3", "GPT2LMHeadModel"),
|
||||
("gpt2", "GPT2LMHeadModel"),
|
||||
("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
@ -518,6 +520,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma", "GemmaForCausalLM"),
|
||||
("gemma2", "Gemma2ForCausalLM"),
|
||||
("gemma3", "Gemma3ForCausalLM"),
|
||||
("gemma3_text", "Gemma3ForCausalLM"),
|
||||
("git", "GitForCausalLM"),
|
||||
("glm", "GlmForCausalLM"),
|
||||
("got_ocr2", "GotOcr2ForConditionalGeneration"),
|
||||
@ -824,6 +828,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("chameleon", "ChameleonForConditionalGeneration"),
|
||||
("emu3", "Emu3ForConditionalGeneration"),
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("git", "GitForCausalLM"),
|
||||
("got_ocr2", "GotOcr2ForConditionalGeneration"),
|
||||
("idefics", "IdeficsForVisionText2Text"),
|
||||
|
@ -63,6 +63,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3Processor"),
|
||||
("flava", "FlavaProcessor"),
|
||||
("fuyu", "FuyuProcessor"),
|
||||
("gemma3", "Gemma3Processor"),
|
||||
("git", "GitProcessor"),
|
||||
("got_ocr2", "GotOcr2Processor"),
|
||||
("grounding-dino", "GroundingDinoProcessor"),
|
||||
|
@ -215,6 +215,13 @@ else:
|
||||
"GemmaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"gemma3",
|
||||
(
|
||||
"GemmaTokenizer" if is_sentencepiece_available() else None,
|
||||
"GemmaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
|
@ -41,7 +41,6 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -936,42 +935,23 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||
# which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
@ -994,19 +974,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
if logits_to_keep is not None:
|
||||
model_inputs["logits_to_keep"] = logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ...utils import logging
|
||||
from ..gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaForCausalLM,
|
||||
@ -686,42 +686,23 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
):
|
||||
# Overwritten: has a special cache type, `HybridCache`
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
|
||||
# (we can't check exception 3 while compiling)
|
||||
if past_key_values is not None:
|
||||
if (
|
||||
inputs_embeds is not None # Exception 1
|
||||
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
|
||||
):
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
|
||||
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
|
||||
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
|
||||
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
|
||||
# which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
cache_position=cache_position,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
if logits_to_keep is None:
|
||||
_ = model_inputs.pop("logits_to_keep", None)
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
@ -744,19 +725,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
if logits_to_keep is not None:
|
||||
model_inputs["logits_to_keep"] = logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
30
src/transformers/models/gemma3/__init__.py
Normal file
30
src/transformers/models/gemma3/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_gemma3 import *
|
||||
from .image_processing_gemma3 import *
|
||||
from .image_processing_gemma3_fast import *
|
||||
from .modeling_gemma3 import *
|
||||
from .processing_gemma3 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
330
src/transformers/models/gemma3/configuration_gemma3.py
Normal file
330
src/transformers/models/gemma3/configuration_gemma3.py
Normal file
@ -0,0 +1,330 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
from ..siglip import SiglipVisionConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma3Text-7B.
|
||||
e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 262208):
|
||||
Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Gemma3TextModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2304):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 9216):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 26):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
Scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
||||
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
|
||||
>>> configuration = Gemma3TextConfig()
|
||||
>>> # Initializing a model from the gemma3_text-7b style configuration
|
||||
>>> model = Gemma3TextModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
"""
|
||||
|
||||
model_type = "gemma3_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=262_208,
|
||||
hidden_size=2304,
|
||||
intermediate_size=9216,
|
||||
num_hidden_layers=26,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
head_dim=256,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
max_position_embeddings=131_072,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=1_000_000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
query_pre_attn_scalar=256,
|
||||
sliding_window=4096,
|
||||
final_logit_softcapping=None,
|
||||
attn_logit_softcapping=None,
|
||||
cache_implementation="hybrid",
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq=10_000.0,
|
||||
sliding_window_pattern=6,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_activation = hidden_activation
|
||||
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||
self.sliding_window = sliding_window
|
||||
self.final_logit_softcapping = final_logit_softcapping
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.cache_implementation = cache_implementation
|
||||
|
||||
self.rope_local_base_freq = rope_local_base_freq
|
||||
# For configuring HybridCache to work with 5:1 attention pattern
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.rope_scaling = rope_scaling
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class Gemma3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
||||
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
||||
|
||||
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
Custom vision config or dict.
|
||||
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
||||
The number of tokens per image embedding.
|
||||
boi_token_index (`int`, *optional*, defaults to 255999):
|
||||
The begin-of-image token index to wrap the image prompt.
|
||||
eoi_token_index (`int`, *optional*, defaults to 256000):
|
||||
The end-of-image token index to wrap the image prompt.
|
||||
image_token_index (`int`, *optional*, defaults to 262144):
|
||||
The image token index to encode the image prompt.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
||||
|
||||
>>> # Initializing a Siglip-like vision config
|
||||
>>> vision_config = SiglipVisionConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 Text config
|
||||
>>> text_config = Gemma3TextConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
||||
>>> configuration = Gemma3Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the gemma-3-4b style configuration
|
||||
>>> model = Gemma3TextConfig(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma3"
|
||||
sub_configs = {
|
||||
"text_config": Gemma3TextConfig,
|
||||
"vision_config": SiglipVisionConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config: Optional[Gemma3TextConfig] = None,
|
||||
vision_config: Optional[SiglipVisionConfig] = None,
|
||||
mm_tokens_per_image: int = 256,
|
||||
boi_token_index: int = 255_999,
|
||||
eoi_token_index: int = 256_000,
|
||||
image_token_index: int = 262_144,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if text_config is None:
|
||||
text_config = Gemma3TextConfig()
|
||||
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = Gemma3TextConfig(**text_config)
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
vision_config = SiglipVisionConfig()
|
||||
logger.info(
|
||||
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
||||
"to text tasks."
|
||||
)
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.mm_tokens_per_image = mm_tokens_per_image
|
||||
self.boi_token_index = boi_token_index
|
||||
self.eoi_token_index = eoi_token_index
|
||||
self.image_token_index = image_token_index
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
|
@ -0,0 +1,592 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
|
||||
|
||||
python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \
|
||||
--variant='gemma3_4b' \
|
||||
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
|
||||
--checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \
|
||||
--output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \
|
||||
--precision='bfloat16'
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any
|
||||
|
||||
import accelerate
|
||||
import numpy as np
|
||||
import torch
|
||||
import tree
|
||||
from absl import app, flags, logging
|
||||
from orbax import checkpoint as obc
|
||||
|
||||
from ...image_utils import PILImageResampling
|
||||
from ..gemma import GemmaTokenizerFast
|
||||
from . import (
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3ImageProcessor,
|
||||
Gemma3Processor,
|
||||
)
|
||||
from .configuration_gemma3 import (
|
||||
Gemma3Config,
|
||||
Gemma3TextConfig,
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
|
||||
|
||||
# ==== Internal Constants and Classes ====
|
||||
|
||||
|
||||
_CHAT_TEMPLATE = """{{ bos_token }}
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- if messages[0]['content'] is string -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
|
||||
{%- endif -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
{{ '<end_of_turn>\n' }}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<start_of_turn>model\n'}}
|
||||
{%- endif -%}
|
||||
"""
|
||||
|
||||
_DTYPES = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
}
|
||||
|
||||
_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder"
|
||||
_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding"
|
||||
_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_"
|
||||
_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
|
||||
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
|
||||
|
||||
_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
|
||||
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
|
||||
_TRANSFORMER_EMBEDDER = "transformer/embedder"
|
||||
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
|
||||
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
|
||||
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
|
||||
|
||||
_VISION_CONFIG = {
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"image_size": 896,
|
||||
"patch_size": 14,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"layer_norm_eps": 1e-6,
|
||||
"attention_dropout": 0.0,
|
||||
"vision_use_head": False,
|
||||
}
|
||||
|
||||
_VARIANT_GEMMA_3_1B = "gemma3_1b"
|
||||
_VARIANT_GEMMA_3_4B = "gemma3_4b"
|
||||
_VARIANT_GEMMA_3_12B = "gemma3_12b"
|
||||
_VARIANT_GEMMA_3_27B = "gemma3_27b"
|
||||
_VARIANTS = {
|
||||
_VARIANT_GEMMA_3_1B: Gemma3Config(
|
||||
text_config=Gemma3TextConfig(
|
||||
vocab_size=262_144,
|
||||
hidden_size=1152,
|
||||
intermediate_size=6 * 1152,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=26,
|
||||
num_key_value_heads=1,
|
||||
head_dim=256,
|
||||
sliding_window=512,
|
||||
rope_theta=1_000_000, # used for global RoPE only
|
||||
rope_local_base_freq=10_000,
|
||||
attn_logit_softcapping=None,
|
||||
query_pre_attn_scalar=256,
|
||||
max_position_embeddings=32_768,
|
||||
),
|
||||
vision_config=None,
|
||||
),
|
||||
_VARIANT_GEMMA_3_4B: Gemma3Config(
|
||||
text_config=Gemma3TextConfig(
|
||||
vocab_size=262_208,
|
||||
hidden_size=2560,
|
||||
intermediate_size=2560 * 8 // 2,
|
||||
num_attention_heads=8,
|
||||
head_dim=256,
|
||||
num_hidden_layers=34,
|
||||
num_key_value_heads=4,
|
||||
sliding_window=1024,
|
||||
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
|
||||
rope_theta=1_000_000,
|
||||
rope_local_base_freq=10_000,
|
||||
attn_logit_softcapping=None,
|
||||
query_pre_attn_scalar=256,
|
||||
),
|
||||
vision_config=_VISION_CONFIG,
|
||||
),
|
||||
_VARIANT_GEMMA_3_12B: Gemma3Config(
|
||||
text_config=Gemma3TextConfig(
|
||||
vocab_size=262_208,
|
||||
hidden_size=30 * 128,
|
||||
intermediate_size=30 * 128 * 8 // 2,
|
||||
num_attention_heads=16,
|
||||
head_dim=256,
|
||||
num_hidden_layers=48,
|
||||
num_key_value_heads=8,
|
||||
sliding_window=1024,
|
||||
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
|
||||
rope_theta=1_000_000,
|
||||
rope_local_base_freq=10_000,
|
||||
attn_logit_softcapping=None,
|
||||
query_pre_attn_scalar=256,
|
||||
),
|
||||
vision_config=_VISION_CONFIG,
|
||||
),
|
||||
_VARIANT_GEMMA_3_27B: Gemma3Config(
|
||||
text_config=Gemma3TextConfig(
|
||||
vocab_size=262_208,
|
||||
hidden_size=42 * 128,
|
||||
intermediate_size=42 * 128 * 8 // 2,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=62,
|
||||
num_key_value_heads=16,
|
||||
head_dim=128,
|
||||
sliding_window=1024,
|
||||
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
|
||||
rope_theta=1_000_000,
|
||||
rope_local_base_freq=10_000,
|
||||
attn_logit_softcapping=None,
|
||||
query_pre_attn_scalar=(42 * 128 // 32), # 1 / sqrt(hidden_size // num_attention_heads)
|
||||
),
|
||||
vision_config=_VISION_CONFIG,
|
||||
),
|
||||
}
|
||||
|
||||
# ==== Flags ====
|
||||
|
||||
CHECKPOINT_PATH = flags.DEFINE_string(
|
||||
name="checkpoint_path",
|
||||
default=None,
|
||||
help="Path to the Orbax checkpoint.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
|
||||
name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
|
||||
)
|
||||
|
||||
OUTPUT_PATH = flags.DEFINE_string(
|
||||
name="output_path",
|
||||
default=None,
|
||||
help="Path to store the HF checkpoint.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
PRECISION = flags.DEFINE_enum(
|
||||
name="precision",
|
||||
default=None,
|
||||
help="The floating point precision (aka dtype) of the model.",
|
||||
enum_values=set(_DTYPES.keys()),
|
||||
required=True,
|
||||
)
|
||||
|
||||
_TEXT_ONLY = flags.DEFINE_bool(
|
||||
name="text_only",
|
||||
default=False,
|
||||
help=(
|
||||
"If True, the model is loaded and saved as a Gemma3ForCausalLM, "
|
||||
"otherwise model saed as Gemma3ForConditionalGeneration."
|
||||
),
|
||||
)
|
||||
|
||||
TOKENIZER_PATH = flags.DEFINE_string(
|
||||
name="tokenizer_path",
|
||||
default=None,
|
||||
help="Path to the SentencePiece model file.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
_VARIANT = flags.DEFINE_enum(
|
||||
name="variant",
|
||||
default=_VARIANT_GEMMA_3_4B,
|
||||
help="The model variant to convert.",
|
||||
enum_values=set(_VARIANTS.keys()),
|
||||
)
|
||||
|
||||
|
||||
def convert_siglip_weight(
|
||||
config: SiglipVisionConfig,
|
||||
paths: Sequence[str],
|
||||
weights: np.ndarray,
|
||||
) -> tuple[str, np.ndarray]:
|
||||
path, prop = paths
|
||||
normalized_path: str = ""
|
||||
updated_weights: np.ndarray = None
|
||||
|
||||
if path == _SIGLIP_BASE:
|
||||
normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight"
|
||||
updated_weights = weights.reshape(-1, config.hidden_size)
|
||||
elif path == _SIGLIP_EMBEDDING:
|
||||
if prop == "kernel":
|
||||
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight"
|
||||
updated_weights = weights.transpose(3, 2, 0, 1)
|
||||
elif prop == "bias":
|
||||
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias"
|
||||
updated_weights = weights
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
||||
elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK):
|
||||
encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:]
|
||||
next_path_seperator_idx = encoder_block_path.find("/")
|
||||
layer_idx = encoder_block_path[:next_path_seperator_idx]
|
||||
encoder_block_path = encoder_block_path[next_path_seperator_idx:]
|
||||
normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
|
||||
|
||||
if encoder_block_path.startswith("/LayerNorm"):
|
||||
normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2"
|
||||
|
||||
if prop == "scale":
|
||||
normalized_path += ".weight"
|
||||
updated_weights = weights.transpose()
|
||||
elif prop == "bias":
|
||||
normalized_path += ".bias"
|
||||
updated_weights = weights
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
|
||||
elif encoder_block_path.startswith("/MlpBlock_0"):
|
||||
normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
|
||||
|
||||
if prop == "kernel":
|
||||
normalized_path += ".weight"
|
||||
updated_weights = weights.transpose()
|
||||
elif prop == "bias":
|
||||
normalized_path += ".bias"
|
||||
updated_weights = weights
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
||||
elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
|
||||
if encoder_block_path.endswith("/key"):
|
||||
normalized_path += ".self_attn.k_proj"
|
||||
elif encoder_block_path.endswith("/out"):
|
||||
normalized_path += ".self_attn.out_proj"
|
||||
elif encoder_block_path.endswith("/query"):
|
||||
normalized_path += ".self_attn.q_proj"
|
||||
elif encoder_block_path.endswith("/value"):
|
||||
normalized_path += ".self_attn.v_proj"
|
||||
else:
|
||||
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
|
||||
|
||||
if prop == "bias":
|
||||
normalized_path += ".bias"
|
||||
updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
|
||||
elif prop == "kernel":
|
||||
normalized_path += ".weight"
|
||||
updated_weights = weights.reshape(-1, config.hidden_size).transpose()
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
|
||||
else:
|
||||
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
|
||||
elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
|
||||
if prop == "scale":
|
||||
normalized_path = "vision_tower.vision_model.post_layernorm.weight"
|
||||
updated_weights = weights.transpose()
|
||||
elif prop == "bias":
|
||||
normalized_path = "vision_tower.vision_model.post_layernorm.bias"
|
||||
updated_weights = weights
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
|
||||
else:
|
||||
raise ValueError(f"Unexpected path `{path}`.")
|
||||
|
||||
if "vision" in normalized_path:
|
||||
print(normalized_path)
|
||||
return normalized_path, updated_weights
|
||||
|
||||
|
||||
def convert_transformer_weights(
|
||||
config: Gemma3TextConfig,
|
||||
paths: Sequence[str],
|
||||
weights: np.ndarray,
|
||||
) -> Iterator[tuple[str, np.ndarray]]:
|
||||
path, prop = paths
|
||||
|
||||
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
|
||||
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
|
||||
|
||||
converted_paths: list[str] = []
|
||||
converted_weights: list[Any] = []
|
||||
|
||||
attn_head_dim = config.num_attention_heads * config.head_dim
|
||||
kv_head_dim = config.num_key_value_heads * config.head_dim
|
||||
|
||||
if path == _TRANSFORMER_EMBEDDER:
|
||||
if prop == "input_embedding":
|
||||
# Tied to language_model.lm_head.weight, assigned at the end.
|
||||
converted_paths = ["language_model.model.embed_tokens.weight"]
|
||||
|
||||
if not _TEXT_ONLY.value:
|
||||
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
|
||||
pre_expansion_embeddings = weights
|
||||
mu = np.mean(pre_expansion_embeddings, axis=0)
|
||||
sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
|
||||
new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
|
||||
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
|
||||
|
||||
converted_weights = [weights]
|
||||
elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"):
|
||||
return zip([], [])
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
|
||||
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
|
||||
if _TEXT_ONLY.value:
|
||||
return zip([], [])
|
||||
|
||||
if path.endswith("/mm_input_projection"):
|
||||
converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("/mm_soft_embedding_norm"):
|
||||
converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
|
||||
converted_weights = [weights]
|
||||
else:
|
||||
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
|
||||
elif path == _TRANSFORMER_FINAL_NORM:
|
||||
converted_paths = ["language_model.model.norm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
|
||||
decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
|
||||
next_path_seperator_idx = decoder_block_path.find("/")
|
||||
layer_idx = decoder_block_path[:next_path_seperator_idx]
|
||||
decoder_block_path = decoder_block_path[next_path_seperator_idx:]
|
||||
|
||||
base_path = f"language_model.model.layers.{layer_idx}"
|
||||
|
||||
if path.endswith("attn/attn_vec_einsum"):
|
||||
converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
|
||||
converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)]
|
||||
elif path.endswith("attn/_key_norm"):
|
||||
converted_paths = [f"{base_path}.self_attn.k_norm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("attn/kv_einsum"):
|
||||
converted_paths = [
|
||||
f"{base_path}.self_attn.k_proj.weight",
|
||||
f"{base_path}.self_attn.v_proj.weight",
|
||||
]
|
||||
k_proj_weights, v_proj_weights = weights
|
||||
converted_weights = [
|
||||
k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
|
||||
v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
|
||||
]
|
||||
elif path.endswith("attn/q_einsum"):
|
||||
converted_paths = [f"{base_path}.self_attn.q_proj.weight"]
|
||||
converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)]
|
||||
elif path.endswith("attn/_query_norm"):
|
||||
converted_paths = [f"{base_path}.self_attn.q_norm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("mlp/gating_einsum"):
|
||||
converted_paths = [
|
||||
f"{base_path}.mlp.gate_proj.weight",
|
||||
f"{base_path}.mlp.up_proj.weight",
|
||||
]
|
||||
gate_proj_weight, up_proj_weight = weights
|
||||
converted_weights = [gate_proj_weight, up_proj_weight]
|
||||
elif path.endswith("mlp/linear"):
|
||||
converted_paths = [f"{base_path}.mlp.down_proj.weight"]
|
||||
converted_weights = [weights.transpose()]
|
||||
elif path.endswith("post_attention_norm"):
|
||||
converted_paths = [f"{base_path}.post_attention_layernorm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("post_ffw_norm"):
|
||||
converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("pre_attention_norm"):
|
||||
converted_paths = [f"{base_path}.input_layernorm.weight"]
|
||||
converted_weights = [weights]
|
||||
elif path.endswith("pre_ffw_norm"):
|
||||
converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"]
|
||||
converted_weights = [weights]
|
||||
else:
|
||||
raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
|
||||
else:
|
||||
raise ValueError(f"Unexpected path `{path}`.")
|
||||
|
||||
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
|
||||
raise ValueError(
|
||||
"The `converted_paths` and `converted_weights` should be the same "
|
||||
f"length. Got {cpl} and {cwl}, respectively, for {path}."
|
||||
)
|
||||
|
||||
return zip(converted_paths, converted_weights)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ConversionResult:
|
||||
state_tree: dict[str, torch.Tensor]
|
||||
config: Gemma3Config
|
||||
|
||||
|
||||
def convert(
|
||||
checkpoint_path: str,
|
||||
config: Gemma3Config,
|
||||
target_dtype: torch.dtype,
|
||||
) -> ConversionResult:
|
||||
"""Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
|
||||
checkpointer = obc.PyTreeCheckpointer()
|
||||
ckpt = checkpointer.restore(checkpoint_path)
|
||||
hf_tree: dict[str, torch.Tensor] = {}
|
||||
|
||||
def update_tree(path: str, weights: np.ndarray) -> None:
|
||||
torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype)
|
||||
logging.info(
|
||||
"%s converted shape=%s with dtype=%s",
|
||||
path,
|
||||
weights.shape,
|
||||
torch_tensor.dtype,
|
||||
)
|
||||
hf_tree[path] = torch_tensor
|
||||
|
||||
for paths, value in tree.flatten_with_path(ckpt):
|
||||
if paths[0].startswith("SigLiPFromPatches_"):
|
||||
if config.vision_config is None:
|
||||
continue
|
||||
|
||||
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
|
||||
update_tree(path, weights)
|
||||
else:
|
||||
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
|
||||
if config.vision_config is None:
|
||||
path = path[len("language_model.") :]
|
||||
|
||||
update_tree(path, weights)
|
||||
|
||||
if config.vision_config is None:
|
||||
hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
|
||||
else:
|
||||
hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
|
||||
|
||||
return ConversionResult(state_tree=hf_tree, config=config)
|
||||
|
||||
|
||||
def main(*args):
|
||||
del args
|
||||
|
||||
variant = _VARIANT.value
|
||||
dtype = getattr(torch, PRECISION.value)
|
||||
config = _VARIANTS[variant]
|
||||
output_path = OUTPUT_PATH.value
|
||||
|
||||
if variant == _VARIANT_GEMMA_3_1B:
|
||||
flags.FLAGS.set_default(_TEXT_ONLY.name, True)
|
||||
|
||||
tokenizer = GemmaTokenizerFast(
|
||||
TOKENIZER_PATH.value,
|
||||
add_bos_token=True,
|
||||
extra_special_tokens={
|
||||
"image_token": "<image_soft_token>", # Should be ID=262_144
|
||||
"boi_token": "<start_of_image>", # Should be ID=255_999
|
||||
"eoi_token": "<end_of_image>", # Should be ID=256_000
|
||||
},
|
||||
)
|
||||
|
||||
if INCLUDE_CHAT_TEMPLATE.value:
|
||||
# Include chat template for CausalLM models
|
||||
tokenizer.chat_template = _CHAT_TEMPLATE
|
||||
config.eos_token_id = [1, 106]
|
||||
|
||||
if _TEXT_ONLY.value:
|
||||
config.vision_config = None
|
||||
tokenizer.save_pretrained(output_path)
|
||||
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
|
||||
del tokenizer
|
||||
else:
|
||||
image_processor = Gemma3ImageProcessor(
|
||||
image_seq_length=256,
|
||||
image_mean=(0.5,) * 3,
|
||||
image_std=(0.5,) * 3,
|
||||
size={"height": 896, "width": 896},
|
||||
resample=PILImageResampling.BILINEAR,
|
||||
)
|
||||
processor = Gemma3Processor(
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
if INCLUDE_CHAT_TEMPLATE.value:
|
||||
# Duplicate so multimodal instruct models can also be used for CausalLM
|
||||
processor.chat_template = tokenizer.chat_template
|
||||
|
||||
processor.save_pretrained(output_path)
|
||||
logging.info("Saved Gemma3Processor for %s to %s", variant, output_path)
|
||||
del processor
|
||||
del tokenizer
|
||||
|
||||
logging.info("Gemma 3 (%s) configured as: %s", variant, config)
|
||||
logging.info("Converting Gemma 3 (%s) @ %s", variant, dtype)
|
||||
result = convert(CHECKPOINT_PATH.value, config, dtype)
|
||||
logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
if config.vision_config is None:
|
||||
model = Gemma3ForCausalLM(config=config.text_config)
|
||||
else:
|
||||
model = Gemma3ForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(result.state_tree, assign=True, strict=True)
|
||||
model.config.torch_dtype = dtype
|
||||
logging.info("Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.", variant, type(model).__name__)
|
||||
model.save_pretrained(output_path, safe_serialization=True)
|
||||
logging.info(
|
||||
"Saved Gemma 3 (%s) to SafeTensors in %s using %s",
|
||||
variant,
|
||||
output_path,
|
||||
type(model).__name__,
|
||||
)
|
||||
del model
|
||||
del result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
413
src/transformers/models/gemma3/image_processing_gemma3.py
Normal file
413
src/transformers/models/gemma3/image_processing_gemma3.py
Normal file
@ -0,0 +1,413 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for Gemma3."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_nested_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
class Gemma3ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a SigLIP image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||
`do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
do_pan_and_scan (`bool`, *optional*):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
do_pan_and_scan: bool = None,
|
||||
pan_and_scan_min_crop_size: int = None,
|
||||
pan_and_scan_max_num_crops: int = None,
|
||||
pan_and_scan_min_ratio_to_activate: float = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_pan_and_scan = do_pan_and_scan
|
||||
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
|
||||
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
|
||||
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
|
||||
|
||||
def pan_and_scan(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds
|
||||
minumum allowed ratio.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
height, width = get_image_size(image)
|
||||
|
||||
# Square or landscape image.
|
||||
if width >= height:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if width / height < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
|
||||
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_w = max(2, num_crops_w)
|
||||
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
||||
num_crops_h = 1
|
||||
|
||||
# Portrait image.
|
||||
else:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if height / width < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_h = int(math.floor(height / width + 0.5))
|
||||
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_h = max(2, num_crops_h)
|
||||
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
||||
num_crops_w = 1
|
||||
|
||||
crop_size_w = int(math.ceil(width / num_crops_w))
|
||||
crop_size_h = int(math.ceil(height / num_crops_h))
|
||||
|
||||
# Don't apply PaS if crop size is too small.
|
||||
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
||||
return []
|
||||
|
||||
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
||||
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
||||
|
||||
if input_data_format == ChannelDimension.LAST:
|
||||
image_crops = [
|
||||
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
||||
]
|
||||
else:
|
||||
image_crops = [
|
||||
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
||||
]
|
||||
|
||||
return image_crops
|
||||
|
||||
def _process_images_for_pan_and_scan(
|
||||
self,
|
||||
images: List[np.ndarray],
|
||||
do_pan_and_scan: bool,
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
pas_images_list = []
|
||||
num_crops = []
|
||||
for image in images:
|
||||
pas_images = self.pan_and_scan(
|
||||
image=image,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
pas_images_list.extend([image] + pas_images)
|
||||
num_crops.append(len(pas_images))
|
||||
return pas_images_list, num_crops
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
do_pan_and_scan: bool = None,
|
||||
pan_and_scan_min_crop_size: int = None,
|
||||
pan_and_scan_max_num_crops: int = None,
|
||||
pan_and_scan_min_ratio_to_activate: float = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
do_pan_and_scan = do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
|
||||
pan_and_scan_min_crop_size = (
|
||||
pan_and_scan_min_crop_size if pan_and_scan_min_crop_size is not None else self.pan_and_scan_min_crop_size
|
||||
)
|
||||
pan_and_scan_max_num_crops = (
|
||||
pan_and_scan_max_num_crops if pan_and_scan_max_num_crops is not None else self.pan_and_scan_max_num_crops
|
||||
)
|
||||
pan_and_scan_min_ratio_to_activate = (
|
||||
pan_and_scan_min_ratio_to_activate
|
||||
if pan_and_scan_min_ratio_to_activate is not None
|
||||
else self.pan_and_scan_min_ratio_to_activate
|
||||
)
|
||||
|
||||
images_list = make_nested_list_of_images(images)
|
||||
|
||||
if not valid_images(images_list[0]):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
if do_convert_rgb:
|
||||
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
|
||||
|
||||
if do_rescale and is_scaled_image(images_list[0][0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||
|
||||
if do_pan_and_scan:
|
||||
images_list_and_num_crops = [
|
||||
self._process_images_for_pan_and_scan(
|
||||
images=images,
|
||||
do_pan_and_scan=do_pan_and_scan,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for images in images_list
|
||||
]
|
||||
images_list = [images for images, _ in images_list_and_num_crops]
|
||||
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
||||
else:
|
||||
num_crops = [[0] for images in images_list]
|
||||
|
||||
processed_images = []
|
||||
for images in images_list:
|
||||
for image in images:
|
||||
if do_resize:
|
||||
height, width = size["height"], size["width"]
|
||||
image = resize(
|
||||
image=image, size=(height, width), resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
processed_images.append(image)
|
||||
|
||||
data = {"pixel_values": processed_images, "num_crops": num_crops}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Gemma3ImageProcessor"]
|
387
src/transformers/models/gemma3/image_processing_gemma3_fast.py
Normal file
387
src/transformers/models/gemma3/image_processing_gemma3_fast.py
Normal file
@ -0,0 +1,387 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for SigLIP."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
BatchFeature,
|
||||
DefaultFastImageProcessorInitKwargs,
|
||||
DefaultFastImageProcessorPreprocessKwargs,
|
||||
get_size_dict,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_nested_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
|
||||
do_pan_and_scan: Optional[bool]
|
||||
pan_and_scan_min_crop_size: Optional[int]
|
||||
pan_and_scan_max_num_crops: Optional[int]
|
||||
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||
|
||||
|
||||
class Gemma3FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
|
||||
do_pan_and_scan: Optional[bool]
|
||||
pan_and_scan_min_crop_size: Optional[int]
|
||||
pan_and_scan_max_num_crops: Optional[int]
|
||||
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of Pan adn Scan cropping method.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
do_pan_and_scan (`bool`, *optional*):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
""",
|
||||
)
|
||||
class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 224, "width": 224}
|
||||
default_to_square = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_pan_and_scan = None
|
||||
pan_and_scan_min_crop_size = None
|
||||
pan_and_scan_max_num_crops = None
|
||||
pan_and_scan_min_ratio_to_activate = None
|
||||
valid_init_kwargs = Gemma3FastImageProcessorInitKwargs
|
||||
valid_preprocess_kwargs = Gemma3FastImageProcessorPreprocessKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_nested_list_of_images(images)
|
||||
|
||||
def _prepare_input_images(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_convert_rgb: bool = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> List["torch.Tensor"]:
|
||||
"""
|
||||
Prepare the input images for processing.
|
||||
"""
|
||||
batch_images = self._prepare_images_structure(images)
|
||||
process_image_fn = partial(
|
||||
self._process_image,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
# todo: yoni - check if we can parallelize this efficiently
|
||||
batch_processed_images = []
|
||||
for image_list in batch_images:
|
||||
processed_images = []
|
||||
for image in image_list:
|
||||
processed_images.append(process_image_fn(image))
|
||||
batch_processed_images.append(processed_images)
|
||||
|
||||
return batch_processed_images
|
||||
|
||||
def pan_and_scan(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
):
|
||||
"""
|
||||
Pan and Scan an image, by cropping into smaller images when the aspect ratio exceeds
|
||||
minumum allowed ratio.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Image to resize.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
"""
|
||||
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
|
||||
# Square or landscape image.
|
||||
if width >= height:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if width / height < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding.
|
||||
num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_w = max(2, num_crops_w)
|
||||
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
||||
num_crops_h = 1
|
||||
|
||||
# Portrait image.
|
||||
else:
|
||||
# Only apply PaS if the image is sufficiently exaggerated
|
||||
if height / width < pan_and_scan_min_ratio_to_activate:
|
||||
return []
|
||||
|
||||
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||
num_crops_h = int(math.floor(height / width + 0.5))
|
||||
num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)
|
||||
|
||||
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||
num_crops_h = max(2, num_crops_h)
|
||||
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
||||
num_crops_w = 1
|
||||
|
||||
crop_size_w = int(math.ceil(width / num_crops_w))
|
||||
crop_size_h = int(math.ceil(height / num_crops_h))
|
||||
|
||||
# Don't apply PaS if crop size is too small.
|
||||
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
||||
return []
|
||||
|
||||
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
||||
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
||||
|
||||
return [
|
||||
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||
for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
|
||||
]
|
||||
|
||||
def _process_images_for_pan_and_scan(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
do_pan_and_scan: bool,
|
||||
pan_and_scan_min_crop_size: int,
|
||||
pan_and_scan_max_num_crops: int,
|
||||
pan_and_scan_min_ratio_to_activate: float,
|
||||
):
|
||||
pas_images_list = []
|
||||
num_crops = []
|
||||
for image in images:
|
||||
pas_images = self.pan_and_scan(
|
||||
image=image,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
)
|
||||
pas_images_list.extend([image] + pas_images)
|
||||
num_crops.append(len(pas_images))
|
||||
return pas_images_list, num_crops
|
||||
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
"""
|
||||
do_pan_and_scan (`bool`, *optional*):
|
||||
Whether to apply `pan_and_scan` to images.
|
||||
pan_and_scan_min_crop_size (`int`, *optional*):
|
||||
Minimum size of each crop in pan and scan.
|
||||
pan_and_scan_max_num_crops (`int`, *optional*):
|
||||
Maximum number of crops per image in pan and scan.
|
||||
pan_and_scan_min_ratio_to_activate (`float`, *optional*):
|
||||
Minimum aspect ratio to activate pan and scan.
|
||||
""",
|
||||
)
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[Gemma3FastImageProcessorPreprocessKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys()
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_preprocess_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
|
||||
images = self._prepare_input_images(
|
||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||
)
|
||||
|
||||
# Pop kwargs that need further processing or won't be used in _preprocess
|
||||
default_to_square = kwargs.pop("default_to_square")
|
||||
size = kwargs.pop("size")
|
||||
crop_size = kwargs.pop("crop_size")
|
||||
image_mean = kwargs.pop("image_mean")
|
||||
image_std = kwargs.pop("image_std")
|
||||
data_format = kwargs.pop("data_format")
|
||||
resample = kwargs.pop("resample")
|
||||
|
||||
# Make hashable for cache
|
||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
|
||||
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
|
||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
||||
|
||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
||||
size=size,
|
||||
crop_size=crop_size,
|
||||
resample=resample,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
|
||||
device=images[0][0].device,
|
||||
do_resize=kwargs.get("do_resize"),
|
||||
do_center_crop=kwargs.get("do_center_crop"),
|
||||
do_rescale=kwargs.get("do_rescale"),
|
||||
rescale_factor=kwargs.get("rescale_factor"),
|
||||
do_normalize=kwargs.get("do_normalize"),
|
||||
return_tensors=kwargs.get("return_tensors"),
|
||||
)
|
||||
|
||||
return self._preprocess(
|
||||
images=images,
|
||||
size=size,
|
||||
crop_size=crop_size,
|
||||
interpolation=interpolation,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List[List["torch.Tensor"]],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
do_pan_and_scan: Optional[bool],
|
||||
pan_and_scan_min_crop_size: Optional[int],
|
||||
pan_and_scan_max_num_crops: Optional[int],
|
||||
pan_and_scan_min_ratio_to_activate: Optional[float],
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
batch_num_crops = []
|
||||
|
||||
for image_list in images:
|
||||
if do_pan_and_scan:
|
||||
images_list, num_crops = self._process_images_for_pan_and_scan(
|
||||
images=image_list,
|
||||
do_pan_and_scan=do_pan_and_scan,
|
||||
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||
)
|
||||
else:
|
||||
num_crops = [[0] for images in images_list]
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.extend(processed_image_patches)
|
||||
batch_num_crops.extend(num_crops)
|
||||
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Gemma3ImageProcessorFast"]
|
1451
src/transformers/models/gemma3/modeling_gemma3.py
Normal file
1451
src/transformers/models/gemma3/modeling_gemma3.py
Normal file
File diff suppressed because it is too large
Load Diff
848
src/transformers/models/gemma3/modular_gemma3.py
Normal file
848
src/transformers/models/gemma3/modular_gemma3.py
Normal file
@ -0,0 +1,848 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
ModelOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ..bart.modeling_bart import BartScaledWordEmbedding
|
||||
from ..gemma2.configuration_gemma2 import Gemma2Config
|
||||
from ..gemma2.modeling_gemma2 import (
|
||||
Gemma2Attention,
|
||||
Gemma2ForCausalLM,
|
||||
Gemma2MLP,
|
||||
Gemma2Model,
|
||||
Gemma2PreTrainedModel,
|
||||
Gemma2RMSNorm,
|
||||
Gemma2RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
from ..siglip import SiglipVisionConfig
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/gemma-3-4b"
|
||||
_CONFIG_FOR_DOC = "Gemma3Config"
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
GEMMA3_INPUTS_DOCSTRING = ""
|
||||
|
||||
|
||||
class Gemma3TextConfig(Gemma2Config):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma3Text-7B.
|
||||
e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 262208):
|
||||
Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Gemma3TextModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2304):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 9216):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 26):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 256):
|
||||
Scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
final_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*):
|
||||
Scaling factor when applying tanh softcapping on the attention scores.
|
||||
cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3TextModel, Gemma3TextConfig
|
||||
>>> # Initializing a Gemma3Text gemma3_text-7b style configuration
|
||||
>>> configuration = Gemma3TextConfig()
|
||||
>>> # Initializing a model from the gemma3_text-7b style configuration
|
||||
>>> model = Gemma3TextModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings for local attention.
|
||||
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||
Pattern for the sliding window attention.
|
||||
"""
|
||||
|
||||
model_type = "gemma3_text"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=262_208,
|
||||
rope_theta=1_000_000.0,
|
||||
rope_scaling=None,
|
||||
rope_local_base_freq=10_000.0,
|
||||
sliding_window_pattern=6,
|
||||
max_position_embeddings=131_072,
|
||||
final_logit_softcapping=None,
|
||||
attn_logit_softcapping=None,
|
||||
**super_kwargs,
|
||||
):
|
||||
super().__init__(self, **super_kwargs)
|
||||
|
||||
self.rope_local_base_freq = rope_local_base_freq
|
||||
# For configuring HybridCache to work with 5:1 attention pattern
|
||||
self.sliding_window_pattern = sliding_window_pattern
|
||||
self.rope_scaling = rope_scaling
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class Gemma3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
||||
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
||||
|
||||
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
Custom vision config or dict.
|
||||
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
||||
The number of tokens per image embedding.
|
||||
boi_token_index (`int`, *optional*, defaults to 255999):
|
||||
The begin-of-image token index to wrap the image prompt.
|
||||
eoi_token_index (`int`, *optional*, defaults to 256000):
|
||||
The end-of-image token index to wrap the image prompt.
|
||||
image_token_index (`int`, *optional*, defaults to 262144):
|
||||
The image token index to encode the image prompt.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
||||
|
||||
>>> # Initializing a Siglip-like vision config
|
||||
>>> vision_config = SiglipVisionConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 Text config
|
||||
>>> text_config = Gemma3TextConfig()
|
||||
|
||||
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
||||
>>> configuration = Gemma3Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the gemma-3-4b style configuration
|
||||
>>> model = Gemma3TextConfig(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma3"
|
||||
sub_configs = {
|
||||
"text_config": Gemma3TextConfig,
|
||||
"vision_config": SiglipVisionConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config: Optional[Gemma3TextConfig] = None,
|
||||
vision_config: Optional[SiglipVisionConfig] = None,
|
||||
mm_tokens_per_image: int = 256,
|
||||
boi_token_index: int = 255_999,
|
||||
eoi_token_index: int = 256_000,
|
||||
image_token_index: int = 262_144,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
if text_config is None:
|
||||
text_config = Gemma3TextConfig()
|
||||
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = Gemma3TextConfig(**text_config)
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
vision_config = SiglipVisionConfig()
|
||||
logger.info(
|
||||
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
||||
"to text tasks."
|
||||
)
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.mm_tokens_per_image = mm_tokens_per_image
|
||||
self.boi_token_index = boi_token_index
|
||||
self.eoi_token_index = eoi_token_index
|
||||
self.image_token_index = image_token_index
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gemma3CausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for Gemma3 causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class Gemma3TextScaledWordEmbedding(BartScaledWordEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3MLP(Gemma2MLP):
|
||||
def __init__(self, config: Gemma3TextConfig):
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
class Gemma3RMSNorm(Gemma2RMSNorm):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
|
||||
def __init__(self, config: Gemma3TextConfig, device=None):
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
|
||||
class Gemma3Attention(Gemma2Attention):
|
||||
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
||||
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
|
||||
|
||||
super().__init__()
|
||||
self.sliding_window = config.sliding_window if self.is_sliding else None
|
||||
|
||||
self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states = self.q_norm(query_states)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
|
||||
"Falling back to eager attention. This warning can be removed using the argument "
|
||||
'`attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask.to(query_states),
|
||||
dropout=self.attention_dropout if self.training else 0.0,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.is_sliding = self.self_attn.is_sliding
|
||||
self.sliding_window = config.sliding_window
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings_global: torch.Tensor,
|
||||
position_embeddings_local: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(attention_mask.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# apply global RoPE to non-sliding layer only
|
||||
if self.self_attn.is_sliding:
|
||||
position_embeddings = position_embeddings_local
|
||||
else:
|
||||
position_embeddings = position_embeddings_global
|
||||
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
GEMMA3_START_DOCSTRING = None
|
||||
|
||||
|
||||
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
|
||||
base_model_prefix = "language_model"
|
||||
_no_split_modules = [
|
||||
"Gemma3DecoderLayer",
|
||||
"SiglipVisionEmbeddings",
|
||||
"SiglipEncoderLayer",
|
||||
"SiglipMultiheadAttentionPoolingHead",
|
||||
]
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.text_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Gemma3TextModel(Gemma2Model):
|
||||
config_class = Gemma3TextConfig
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig):
|
||||
super().__init__(config)
|
||||
|
||||
# Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
|
||||
self.embed_tokens = Gemma3TextScaledWordEmbedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
|
||||
)
|
||||
|
||||
# TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas
|
||||
# when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
|
||||
config = copy.deepcopy(config)
|
||||
config.rope_theta = config.rope_local_base_freq
|
||||
config.rope_scaling = {"rope_type": "default"}
|
||||
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
batch_size, seq_len, _ = inputs_embeds.shape
|
||||
past_key_values = HybridCache(
|
||||
self.config,
|
||||
max_batch_size=batch_size,
|
||||
max_cache_len=seq_len,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
|
||||
position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
output = BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class Gemma3ForCausalLM(Gemma2ForCausalLM):
|
||||
config_class = Gemma3TextConfig
|
||||
base_model_prefix = "language_model"
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig):
|
||||
super().__init__(config)
|
||||
self.model = Gemma3TextModel(config)
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(nn.Module):
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
|
||||
self.mm_input_projection_weight = nn.Parameter(
|
||||
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
|
||||
)
|
||||
|
||||
self.mm_soft_emb_norm = Gemma3RMSNorm(
|
||||
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
|
||||
)
|
||||
|
||||
self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
|
||||
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
|
||||
|
||||
def forward(self, vision_outputs: torch.Tensor):
|
||||
batch_size, _, seq_length = vision_outputs.shape
|
||||
|
||||
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||
)
|
||||
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||
|
||||
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||
|
||||
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||
|
||||
projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
|
||||
return projected_vision_outputs.type_as(vision_outputs)
|
||||
|
||||
|
||||
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||
image_features = self.multi_modal_projector(vision_outputs)
|
||||
return image_features
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
input_tensor,
|
||||
is_training: bool = False,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
return attention_mask
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted
|
||||
# form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
# Apply bidirectional mask on images if token type ids are provided
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
token_type_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Gemma3Config",
|
||||
"Gemma3TextConfig",
|
||||
"Gemma3PreTrainedModel", # noqa: F822
|
||||
"Gemma3TextModel",
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
]
|
172
src/transformers/models/gemma3/processing_gemma3.py
Normal file
172
src/transformers/models/gemma3/processing_gemma3.py
Normal file
@ -0,0 +1,172 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import to_py_obj
|
||||
|
||||
|
||||
class Gemma3ImagesKwargs(ImagesKwargs):
|
||||
do_pan_and_scan: Optional[bool]
|
||||
pan_and_scan_min_crop_size: Optional[int]
|
||||
pan_and_scan_max_num_crops: Optional[int]
|
||||
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||
do_convert_rgb: Optional[bool]
|
||||
|
||||
|
||||
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: Gemma3ImagesKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {
|
||||
"do_pan_and_scan": False,
|
||||
"pan_and_scan_min_crop_size": 256,
|
||||
"pan_and_scan_max_num_crops": 4,
|
||||
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Gemma3Processor(ProcessorMixin):
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template", "image_seq_length"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor,
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
image_seq_length: int = 256,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_seq_length = image_seq_length
|
||||
self.image_token_id = tokenizer.image_token_id
|
||||
self.boi_token = tokenizer.boi_token
|
||||
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
|
||||
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
||||
|
||||
super().__init__(
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
chat_template=chat_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
videos=None,
|
||||
audio=None,
|
||||
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
if text is None and images is None:
|
||||
raise ValueError("Provide at least one of `text` or `images`.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Gemma3ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_inputs = {}
|
||||
if images is not None:
|
||||
batched_images = make_nested_list_of_images(images)
|
||||
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
|
||||
|
||||
# Create empty text to be replaced with placeholders
|
||||
if not text:
|
||||
text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
|
||||
|
||||
if len(batched_images) != len(text):
|
||||
raise ValueError(
|
||||
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||
)
|
||||
|
||||
# Replace image tokens by the full expanded sequence
|
||||
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||
text_with_crops = text
|
||||
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
||||
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
||||
|
||||
if len(images) != len(image_indexes):
|
||||
raise ValueError(
|
||||
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
||||
)
|
||||
|
||||
# Insert additional image tokens for Pan-and-Scan crops
|
||||
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
||||
if num:
|
||||
formatted_image_text = (
|
||||
f"Here is the original image {self.boi_token} and here are some crops to help you see better "
|
||||
+ " ".join([self.boi_token] * num)
|
||||
)
|
||||
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
||||
text_with_crops[batch_idx] = prompt
|
||||
|
||||
# Expand placeholder image tokens to the full image token sequence
|
||||
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
||||
|
||||
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
||||
array_ids = np.array(text_inputs["input_ids"])
|
||||
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
||||
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
__all__ = ["Gemma3Processor"]
|
@ -477,11 +477,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -490,8 +485,16 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -506,10 +509,16 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
|
@ -4609,6 +4609,34 @@ class Gemma2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma3ForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma3ForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma3PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Gemma3TextModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GitForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -58,6 +58,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class Gemma3ImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class GotOcr2ImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
|
@ -289,6 +289,13 @@ class FuyuProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class Gemma3ImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class GLPNFeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
@ -124,6 +124,7 @@ VLM_CLASS_NAMES = [
|
||||
"qwen2vl",
|
||||
"qwen2_5_vl",
|
||||
"ayavision",
|
||||
"gemma3",
|
||||
]
|
||||
|
||||
|
||||
|
@ -353,7 +353,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def test_Gemma_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
@ -153,6 +153,13 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_sdpa_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
|
||||
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
|
||||
)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
0
tests/models/gemma3/__init__.py
Normal file
0
tests/models/gemma3/__init__.py
Normal file
229
tests/models/gemma3/test_image_processing_gemma3.py
Normal file
229
tests/models/gemma3/test_image_processing_gemma3.py
Normal file
@ -0,0 +1,229 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import Gemma3ImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Gemma3ImageProcessorFast
|
||||
|
||||
|
||||
class Gemma3ImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=IMAGENET_STANDARD_MEAN,
|
||||
image_std=IMAGENET_STANDARD_STD,
|
||||
do_convert_rgb=True,
|
||||
do_pan_and_scan=True,
|
||||
pan_and_scan_min_crop_size=10,
|
||||
pan_and_scan_max_num_crops=2,
|
||||
pan_and_scan_min_ratio_to_activate=1.2,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_pan_and_scan = do_pan_and_scan
|
||||
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
|
||||
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
|
||||
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
"do_pan_and_scan": self.do_pan_and_scan,
|
||||
"pan_and_scan_min_crop_size": self.pan_and_scan_min_crop_size,
|
||||
"pan_and_scan_max_num_crops": self.pan_and_scan_max_num_crops,
|
||||
"pan_and_scan_min_ratio_to_activate": self.pan_and_scan_min_ratio_to_activate,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = Gemma3ImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = Gemma3ImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Gemma3
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = Gemma3ImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
self.assertTrue(hasattr(image_processing, "do_pan_and_scan"))
|
||||
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_crop_size"))
|
||||
self.assertTrue(hasattr(image_processing, "pan_and_scan_max_num_crops"))
|
||||
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_ratio_to_activate"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=84)
|
||||
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
|
||||
|
||||
def test_pan_and_scan(self):
|
||||
"""
|
||||
Enables Pan and Scan path by choosing the correct input image resolution. If you are changing
|
||||
image processor attributes for PaS, please update this test.
|
||||
"""
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
"""This function prepares a list of PIL images"""
|
||||
image_inputs = [np.random.randint(255, size=(3, 300, 600), dtype=np.uint8)] * 3
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
# Test not batched input, 3 images because we have base image + 2 crops
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (3, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched, 9 images because we have base image + 2 crops per each item
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (9, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pil(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
520
tests/models/gemma3/test_modeling_gemma3.py
Normal file
520
tests/models/gemma3/test_modeling_gemma3.py
Normal file
@ -0,0 +1,520 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Gemma3 model."""
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Gemma3Config,
|
||||
Gemma3TextConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...models.gemma.test_modeling_gemma import GemmaModelTester
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3Processor,
|
||||
Gemma3TextModel,
|
||||
)
|
||||
|
||||
|
||||
class Gemma3ModelTester(GemmaModelTester):
|
||||
if is_torch_available():
|
||||
config_class = Gemma3TextConfig
|
||||
model_class = Gemma3TextModel
|
||||
for_causal_lm_class = Gemma3ForCausalLM
|
||||
|
||||
|
||||
@require_torch
|
||||
class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Gemma3TextModel, Gemma3ForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Gemma3ForCausalLM,) if is_torch_available() else ()
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
_is_stateful = True
|
||||
model_split_percents = [0.5, 0.6]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Gemma3ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
|
||||
|
||||
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||
def test_model_outputs_equivalence(self, **kwargs):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
|
||||
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
|
||||
)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3Vision2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
mm_tokens_per_image=2,
|
||||
image_token_index=1,
|
||||
boi_token_index=2,
|
||||
eoi_token_index=3,
|
||||
seq_length=25,
|
||||
is_training=True,
|
||||
vision_config={
|
||||
"use_labels": True,
|
||||
"image_size": 20,
|
||||
"patch_size": 5,
|
||||
"num_channels": 3,
|
||||
"is_training": True,
|
||||
"hidden_size": 32,
|
||||
"num_key_value_heads": 1,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 37,
|
||||
"dropout": 0.1,
|
||||
"attention_dropout": 0.1,
|
||||
"initializer_range": 0.02,
|
||||
},
|
||||
use_cache=False,
|
||||
):
|
||||
self.parent = parent
|
||||
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
|
||||
self.mm_tokens_per_image = mm_tokens_per_image
|
||||
self.image_token_index = image_token_index
|
||||
self.boi_token_index = boi_token_index
|
||||
self.eoi_token_index = eoi_token_index
|
||||
self.llm_tester = Gemma3ModelTester(self.parent)
|
||||
self.text_config = self.llm_tester.get_config()
|
||||
self.vision_config = vision_config
|
||||
self.seq_length = seq_length
|
||||
self.pad_token_id = self.text_config.pad_token_id
|
||||
|
||||
self.num_hidden_layers = self.text_config.num_hidden_layers
|
||||
self.vocab_size = self.text_config.vocab_size
|
||||
self.hidden_size = self.text_config.hidden_size
|
||||
self.num_attention_heads = self.text_config.num_attention_heads
|
||||
self.is_training = is_training
|
||||
|
||||
self.batch_size = 3
|
||||
self.num_channels = vision_config["num_channels"]
|
||||
self.image_size = vision_config["image_size"]
|
||||
self.encoder_seq_length = seq_length
|
||||
self.use_cache = use_cache
|
||||
|
||||
def get_config(self):
|
||||
return Gemma3Config(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
image_token_index=self.image_token_index,
|
||||
boi_token_index=self.boi_token_index,
|
||||
eoi_token_index=self.eoi_token_index,
|
||||
mm_tokens_per_image=self.mm_tokens_per_image,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
self.batch_size,
|
||||
self.vision_config["num_channels"],
|
||||
self.vision_config["image_size"],
|
||||
self.vision_config["image_size"],
|
||||
]
|
||||
)
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
||||
|
||||
# set the 3 first tokens to be image, and ensure that no other tokens are image tokens
|
||||
# do not change this unless you modified image size or patch size
|
||||
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||
input_ids[:, :1] = config.image_token_index
|
||||
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
token_type_ids[input_ids == config.image_token_index] = 1
|
||||
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
_is_stateful = True
|
||||
model_split_percents = [0.5, 0.6]
|
||||
|
||||
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
|
||||
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
|
||||
# in the dispatch_model function
|
||||
test_cpu_offload = False
|
||||
test_disk_offload_safetensors = False
|
||||
test_disk_offload_bin = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Gemma3Vision2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
|
||||
|
||||
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
|
||||
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
|
||||
)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||
def test_model_outputs_equivalence(self, **kwargs):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
|
||||
)
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
||||
)
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
# @require_read_token
|
||||
class Gemma3IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left")
|
||||
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||
self.messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_model_4b_bf16(self):
|
||||
model_id = "gg-hf-g/gemma-3-4b-it"
|
||||
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
self.messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_4b_batch(self):
|
||||
model_id = "gg-hf-g/gemma-3-4b-it"
|
||||
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
messages_2 = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
|
||||
},
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "Are these images identical?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
[self.messages, messages_2],
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = [
|
||||
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
|
||||
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
|
||||
] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_4b_multiimage(self):
|
||||
model_id = "gg-hf-g/gemma-3-4b-it"
|
||||
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What do you see here?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_1b_text_only(self):
|
||||
model_id = "gg-hf-g/gemma-3-1b-it"
|
||||
|
||||
model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
# TODO: raushan FA2 generates gibberish for no reason, check later
|
||||
# @require_flash_attn
|
||||
# @require_torch_gpu
|
||||
# @mark.flash_attn_test
|
||||
# def test_model_4b_flash_attn(self):
|
||||
# model_id = "gg-hf-g/gemma-3-4b-it"
|
||||
#
|
||||
# model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
# model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
# ).to(torch_device)
|
||||
#
|
||||
# inputs = self.processor.apply_chat_template(
|
||||
# self.messages,
|
||||
# tokenize=True,
|
||||
# return_dict=True,
|
||||
# return_tensors="pt",
|
||||
# add_generation_prompt=True,
|
||||
# ).to(torch_device)
|
||||
#
|
||||
# output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||
# output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||
#
|
||||
# EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**'] # fmt: skip
|
||||
# self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
|
||||
def test_generation_beyond_sliding_window(self, attn_implementation: str):
|
||||
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
|
||||
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
|
||||
Outputs for every attention functions should be coherent and identical.
|
||||
"""
|
||||
model_id = "gg-hf-g/gemma-3-1b-it"
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
136
tests/models/gemma3/test_processing_gemma3.py
Normal file
136
tests/models/gemma3/test_processing_gemma3.py
Normal file
@ -0,0 +1,136 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
from transformers import Gemma3Processor, GemmaTokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import Gemma3ImageProcessor
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
@require_vision
|
||||
class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Gemma3Processor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
gemma3_image_processor_kwargs = {
|
||||
"do_pan_and_scan": True,
|
||||
"pan_and_scan_min_crop_size": 256,
|
||||
"pan_and_scan_max_num_crops": 4,
|
||||
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||
}
|
||||
image_processor = Gemma3ImageProcessor.from_pretrained(
|
||||
"google/siglip-so400m-patch14-384", **gemma3_image_processor_kwargs
|
||||
)
|
||||
|
||||
extra_special_tokens = {
|
||||
"image_token": "<image_soft_token>",
|
||||
"boi_token": "<start_of_image>",
|
||||
"eoi_token": "<end_of_image>",
|
||||
}
|
||||
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens)
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
processor = Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
# TODO: raushan or arthur: add the real chat template
|
||||
def prepare_processor_dict(self):
|
||||
return {
|
||||
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n", "image_seq_length": 3,
|
||||
} # fmt: skip
|
||||
|
||||
# Override as VLMs need image tokens in prompts
|
||||
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||
if batch_size is None:
|
||||
return "lower newer <start_of_image>"
|
||||
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be greater than 0")
|
||||
|
||||
if batch_size == 1:
|
||||
return ["lower newer <start_of_image>"]
|
||||
return ["lower newer <start_of_image>", "<start_of_image> upper older longer string"] + [
|
||||
"<start_of_image> lower newer"
|
||||
] * (batch_size - 2)
|
||||
|
||||
# Override as Gemma3 needs images to be an explicitly nested batch
|
||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
||||
"""This function prepares a list of PIL images for testing"""
|
||||
images = super().prepare_image_inputs(batch_size)
|
||||
if isinstance(images, (list, tuple)):
|
||||
images = [[image] for image in images]
|
||||
return images
|
||||
|
||||
def test_text_with_image_tokens(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!"
|
||||
text_single_image = f"{processor.boi_token}Dummy text!"
|
||||
text_no_image = "Dummy text!"
|
||||
|
||||
image = self.prepare_image_inputs()
|
||||
|
||||
# If text has no image tokens, iamge should be `None`
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(text=text_no_image, images=image, return_tensors="np")
|
||||
|
||||
# We can't be sure what is users intention: if user wants one image per text OR two images for first text and no image for second text
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")
|
||||
|
||||
# The users is expected to be explicit about which image belong to which text by nesting the images list
|
||||
out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
|
||||
out_batch_oneimage = processor(
|
||||
text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
|
||||
)
|
||||
self.assertListEqual(
|
||||
out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
|
||||
)
|
||||
|
||||
def test_pan_and_scan(self):
|
||||
processor_components = self.prepare_components()
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
processor = self.processor_class(**processor_components, **processor_kwargs)
|
||||
|
||||
input_str = self.prepare_text_inputs()
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="np",
|
||||
do_pan_and_scan=True,
|
||||
image_seq_length=2,
|
||||
pan_and_scan_min_crop_size=10,
|
||||
)
|
||||
|
||||
# base image + 4 crops
|
||||
self.assertEqual(len(inputs[self.images_input_name]), 5)
|
||||
self.assertEqual(len(inputs[self.text_input_name][0]), 67)
|
@ -783,7 +783,7 @@ class ProcessorTesterMixin:
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
|
||||
|
||||
# Now test the ability to return dict
|
||||
messages[0][0]["content"].append(
|
||||
@ -845,7 +845,7 @@ class ProcessorTesterMixin:
|
||||
return_dict=True,
|
||||
padding=True,
|
||||
)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
|
||||
|
||||
# Now test the ability to return dict
|
||||
batched_messages[0][0]["content"].append(
|
||||
@ -885,6 +885,7 @@ class ProcessorTesterMixin:
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=50,
|
||||
)
|
||||
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
|
||||
@ -982,7 +983,7 @@ class ProcessorTesterMixin:
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
|
||||
|
||||
# Add video URL for return dict and load with `num_frames` arg
|
||||
messages[0][0]["content"][0] = {
|
||||
|
@ -226,6 +226,8 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"GPTNeoXConfig": ["rotary_emb_base"],
|
||||
"Gemma3Config": ["boi_token_index", "eoi_token_index"],
|
||||
"Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user