mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-13 17:48:22 +06:00

* Fix converter
* [Broken] Adds Gemma 3 to Hugging Face Transformers
* Consolidating Config and Processor params across impls
* Sorting out configuration parameters. Adds qk_norm before RoPE. Still not sure if RoPE is right.
* Additional plumbing for CausalLM and ConditionalGeneration variants
* incomplete draft of Orbax conversion script
* More complete checkpoint conversion
* Supporting Gemma 3 1B checkpoints
* Updating RoPE for multiple frequencies
* Adjustments to rotary embedder
* Proof of life for text-only operation
* Updating the conversion script to handle multimodal projection weights
* Fixing text-only conversions
* Cleaner conversion script with multimodal support and a simpler processor
* Additional refactors to the Gemma3Processor
* Simplified Processor to work over text representations
* Updated conversion script to join text and vision embeddings at conversion time
* Logging for debugging
* Update src/transformers/models/gemma2/modeling_gemma2.py
Co-authored-by: Joshua Lochner <admin@xenova.com>
* Removed extraneous Config params
* Switching to fast tokenizer for checkpoint conversions
* isolating siglip for performance testing
* Minor changes for debugging tests against baselines
* Adding average pooling for soft tokens
* Updating processor code to enable simpler embedding interleaving for arbitrary number of images in prompts
* Updating conversion script for ShieldGemma 2 conversion compatibility
* Allow disable_compile to be provided as a kwarg
* Refresh from modular
* Updated conversion script and corrected sliding window
* Fix type mismatch in cache_position (#4)
* Fix dtype (#5)
* Fix type mismatch in cache_position
* Actually fix in the modular file
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
---------
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
* fixes for embedding table overflow and missing image_soft_token_mask from Gemma3Processor
* Adding 2D pooling for image embeddings
* Revert "Adding 2D pooling for image embeddings"
This reverts commit 65350cf531.
* Gemma3 average pooling changed from 1D to 2D
* Major refactor to Gemma3MultimodalInputProjection
* Updating Gemma 3 Auto* registrations
* Add option to save Gemma 3 chat template with tokenizer during weights conversion
* Removing unused imports
* Moving out-of-vocab handling from Gemma3Processor to Gemma3ForConditionalGeneration
* Removing duplicate config property
* Removing final logit softcapping and 1-indexing of position ids
* Fixing image processor config and none --> None typo
* Fixing sliding window size for 1B
* Updating image_mean and image_std in Image Processor
* Attention masking changed to lower triangular
* Moving image special tokens to conversion script
* Mirror image processor defaults from conversion script into Gemma3ProcessorKwargs
* Remove special token variables from symbol space
* Moving image soft token mask computation from Gemma3Processor to Gemma3ForConditionalGeneration
* tie lm_head and embedding weights
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
* Correct tied weights in Gemma3CausalLM
* iterative bidirectional attention
* resolving merge conflicts
* Reverting to Gemma 2 HybridCache with sliding window support and a sliding_window_pattern of 6
* Correcting RoPE scaling
* clean up first pass, dummy model generation works
* final clean up before fixing tests
* causal lm test works, so fine
* Fix conversion
* Update src/transformers/models/gemma3/processing_gemma3.py
* model tests are happy
* processor tests are happy
* image processing tests added
* fixup
* Fix pre-processing in conversion
* Inputs merging
* Do not normalize vision embeddings
* Apply Ryan's (and team) changes to attention
* token type ids + mask
* template
* move embed scale, add rope scale, fix tests
* Add chat template to tokenizer
* Use prefix for causal model loading
* use existing code for sliding mask from gemma2
* self.embed_tokens already normalizes
* Correcting Gemma3TextConfig parameters in conversion script
* typo, modular overwrites my fixes
* enable device map for text model
* Conversion updates
* ultra nit: no einsums
* update image token
* copy deepcopy config + some docs
* add some test, still WIP
* Refactoring --include_chat_template logic in converter
* Update src/transformers/models/gemma3/modular_gemma3.py
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
* Add eos tokens for instruct models
* dump so i can work on dgx
* Removing add_bos by default
* dump
* add fast im proc
* docs for PaS + fixup
* another fixup
* one more fixup
* fix tests
* Inverting prior BOS change
* ultra nit
* Reverting to Tokenizer saved with add_bos_token=True and chat template starting with BOS
* resize embeds, remove sqrt, add slow test outputs
* FA2 but quality is meh
* nit
* skip FA2, no idea what happened
* last bit for green CI
* please, green CI for docs
* T_T
* Fix for Gemma3 logits
* Support both options for system prompt
* Update src/transformers/models/gemma3/image_processing_gemma3_fast.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/model_doc/gemma3.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Docs updates now that assets are live
* Style fixes
---------
Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Lysandre <hi@lysand.re>
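The commit adds a chat template to the tokenizer, keeps the template starting with BOS, and supports an optional system prompt. The placeholder template that the test file's `prepare_processor_dict` passes to `Gemma3Processor` (marked TODO) shows one way to do that: a leading system message is folded into the first user turn instead of becoming its own turn. Below is a minimal sketch that re-expresses that template's logic in plain Python, assuming Jinja's `-` whitespace control strips every literal newline between tags so only the quoted strings reach the output. `render_gemma3_test_template` is a hypothetical name, it raises `ValueError` where the template calls `raise_exception`, and it is not how transformers renders templates (the library evaluates the Jinja template itself).

```python
from typing import Any, Dict, List


def render_gemma3_test_template(
    messages: List[Dict[str, Any]], bos_token: str = "<bos>", add_generation_prompt: bool = False
) -> str:
    """Hypothetical plain-Python reading of the test's placeholder Jinja template."""
    out = bos_token  # the template opens with {{ bos_token }}
    # A leading system message is not its own turn: its first text part is
    # prepended to the first user turn, followed by a blank line.
    if messages and messages[0]["role"] == "system":
        first_user_prefix = messages[0]["content"][0]["text"] + "\n\n"
        loop_messages = messages[1:]
    else:
        first_user_prefix = ""
        loop_messages = messages

    for index, message in enumerate(loop_messages):
        # Turns must alternate user/assistant, starting with user.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        # The template renames the assistant role to "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        out += "<start_of_turn>" + role + "\n" + (first_user_prefix if index == 0 else "")
        content = message["content"]
        if isinstance(content, str):
            out += content.strip()  # equivalent of Jinja's trim filter
        elif isinstance(content, (list, tuple)):
            for item in content:
                if item["type"] == "image":
                    out += "<start_of_image>"  # image marker the processor pairs with an image
                elif item["type"] == "text":
                    out += item["text"].strip()
        else:
            raise ValueError("Invalid content type")
        out += "<end_of_turn>\n"

    if add_generation_prompt:
        out += "<start_of_turn>model\n"
    return out
```

With a system message plus one user turn holding an image and some text, and `add_generation_prompt=True`, this returns `"<bos><start_of_turn>user\nBe brief.\n\n<start_of_image>What is shown?<end_of_turn>\n<start_of_turn>model\n"`; note that the system text is not trimmed, while user text is.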
137 lines
6.9 KiB
Python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest
from typing import Optional

from transformers import Gemma3Processor, GemmaTokenizer
from transformers.testing_utils import get_tests_dir, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import Gemma3ImageProcessor

SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


@require_vision
class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = Gemma3Processor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        gemma3_image_processor_kwargs = {
            "do_pan_and_scan": True,
            "pan_and_scan_min_crop_size": 256,
            "pan_and_scan_max_num_crops": 4,
            "pan_and_scan_min_ratio_to_activate": 1.2,
        }
        image_processor = Gemma3ImageProcessor.from_pretrained(
            "google/siglip-so400m-patch14-384", **gemma3_image_processor_kwargs
        )

        extra_special_tokens = {
            "image_token": "<image_soft_token>",
            "boi_token": "<start_of_image>",
            "eoi_token": "<end_of_image>",
        }
        tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens)
        processor_kwargs = self.prepare_processor_dict()
        processor = Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
        processor.save_pretrained(self.tmpdirname)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    # TODO: raushan or arthur: add the real chat template
    def prepare_processor_dict(self):
        return {
            "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n", "image_seq_length": 3,
        } # fmt: skip

    # Override as VLMs need image tokens in prompts
    def prepare_text_inputs(self, batch_size: Optional[int] = None):
        if batch_size is None:
            return "lower newer <start_of_image>"

        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")

        if batch_size == 1:
            return ["lower newer <start_of_image>"]
        return ["lower newer <start_of_image>", "<start_of_image> upper older longer string"] + [
            "<start_of_image> lower newer"
        ] * (batch_size - 2)

    # Override as Gemma3 needs images to be an explicitly nested batch
    def prepare_image_inputs(self, batch_size: Optional[int] = None):
        """This function prepares a list of PIL images for testing"""
        images = super().prepare_image_inputs(batch_size)
        if isinstance(images, (list, tuple)):
            images = [[image] for image in images]
        return images

    def test_text_with_image_tokens(self):
        image_processor = self.get_component("image_processor")
        tokenizer = self.get_component("tokenizer")

        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
        text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!"
        text_single_image = f"{processor.boi_token}Dummy text!"
        text_no_image = "Dummy text!"

        image = self.prepare_image_inputs()

        # If text has no image tokens, image should be `None`
        with self.assertRaises(ValueError):
            _ = processor(text=text_no_image, images=image, return_tensors="np")

        # We can't be sure of the user's intention: one image per text, OR two images for the
        # first text and no image for the second text
        with self.assertRaises(ValueError):
            _ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")

        # The user is expected to be explicit about which image belongs to which text by nesting the images list
        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
        out_batch_oneimage = processor(
            text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
        )
        self.assertListEqual(
            out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
        )

    def test_pan_and_scan(self):
        processor_components = self.prepare_components()
        processor_kwargs = self.prepare_processor_dict()
        processor = self.processor_class(**processor_components, **processor_kwargs)

        input_str = self.prepare_text_inputs()
        image_input = self.prepare_image_inputs()
        inputs = processor(
            text=input_str,
            images=image_input,
            return_tensors="np",
            do_pan_and_scan=True,
            image_seq_length=2,
            pan_and_scan_min_crop_size=10,
        )

        # base image + 4 crops
        self.assertEqual(len(inputs[self.images_input_name]), 5)
        self.assertEqual(len(inputs[self.text_input_name][0]), 67)
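The comments in `test_text_with_image_tokens`, together with the `prepare_image_inputs` override, describe a pairing rule: as the test's expectations suggest, a flat image list is read as belonging to a single prompt, so a batch of prompts needs one inner list per prompt, and each prompt must carry one `<start_of_image>` marker per image. The following is a minimal sketch of that rule as the test exercises it, not the `Gemma3Processor` source; `nest_images` and `check_image_placement` are hypothetical names, and images are stood in for by plain strings so the sketch runs without PIL.

```python
from typing import Any, List, Sequence, Union

BOI_TOKEN = "<start_of_image>"  # the boi_token configured in the test's setUp


def nest_images(images: Any) -> List[List[Any]]:
    """Hypothetical helper: group images into one inner list per prompt."""
    if isinstance(images, (list, tuple)):
        if images and all(isinstance(item, (list, tuple)) for item in images):
            # Already nested: one inner list per prompt.
            return [list(item) for item in images]
        # Flat list: every image is assumed to belong to a single prompt.
        return [list(images)]
    # A lone image is a batch of one prompt with one image.
    return [[images]]


def check_image_placement(text: Union[str, Sequence[str]], images: Any) -> List[List[Any]]:
    """Hypothetical validator mirroring the cases in test_text_with_image_tokens."""
    texts = [text] if isinstance(text, str) else list(text)
    batched = nest_images(images)
    # One group of images per prompt, otherwise the pairing is ambiguous.
    if len(batched) != len(texts):
        raise ValueError(f"Got {len(batched)} image groups for {len(texts)} prompts")
    # Each prompt needs exactly one image marker per image in its group.
    for prompt, group in zip(texts, batched):
        n_markers = prompt.count(BOI_TOKEN)
        if n_markers != len(group):
            raise ValueError(f"Prompt has {n_markers} image tokens but received {len(group)} images")
    return batched
```

With `single = BOI_TOKEN + "Dummy text!"`, the call `check_image_placement([single, single], [["a"], ["b"]])` passes, while the flat `["a", "b"]` raises `ValueError` because it is read as two images for one prompt, matching the ambiguity the test comment describes.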
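`test_pan_and_scan` expects the base image plus 4 crops when pan-and-scan is enabled with `pan_and_scan_min_crop_size=10` (and, from `setUp`, `pan_and_scan_max_num_crops=4` and `pan_and_scan_min_ratio_to_activate=1.2`). The sketch below plans crop boxes with a simple heuristic of our own that fits those parameter names: split only elongated images, choose a crop count near the aspect ratio, clamp it between 2 and the maximum, and skip cropping if a crop would fall below the minimum size. It is an assumption for illustration, not the `Gemma3ImageProcessor` implementation, and `pan_and_scan_boxes` is a hypothetical name.

```python
import math
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (left, top, right, bottom) in pixels


def pan_and_scan_boxes(
    width: int,
    height: int,
    min_crop_size: int = 256,
    max_num_crops: int = 4,
    min_ratio_to_activate: float = 1.2,
) -> List[Box]:
    """Hypothetical crop planner: split the long side into equal crops."""
    long_side, short_side = max(width, height), min(width, height)
    # Only split images that are clearly elongated.
    if long_side / short_side < min_ratio_to_activate:
        return []
    num_crops = math.floor(long_side / short_side + 0.5)  # aspect ratio, rounded half up
    num_crops = min(num_crops, long_side // min_crop_size)  # crops must fit at min size
    num_crops = max(2, min(max_num_crops, num_crops))  # clamp to [2, max_num_crops]
    crop_len = math.ceil(long_side / num_crops)
    # Give up if the resulting crops would be too small.
    if min(crop_len, short_side) < min_crop_size:
        return []
    boxes = []
    for i in range(num_crops):
        start = i * crop_len
        end = min(start + crop_len, long_side)
        if width >= height:
            boxes.append((start, 0, end, height))  # landscape: crops side by side
        else:
            boxes.append((0, start, width, end))  # portrait: crops stacked vertically
    return boxes
```

For instance, a 400x30 image with `min_crop_size=10` yields four 100x30 crops, so the base image plus its crops gives five images; a square image yields no crops because its aspect ratio is below 1.2.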