mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-17 03:28:22 +06:00

* Gemma 3n * initial commit of Gemma 3n scaffold * Fixing param pass through on Gemm3p5RMSNorm * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma3p5 overall and text config with vision and audio config placeholders (#3) * Adding gemma3p5 text configs * Adding audio config placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3n (#3) * Initial Gemm3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3.5 * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * regenerating modeling file after 
syncing to HEAD * Use torch.std(..., unbiased=False) for activation sparsity (#8) * Refactoring to a single QVK Norm (#13) * AltUp: support scale_corrected_output (#14) * Converts einsums to nn.Linear (#7) * Converts einsums to nn.Linear * Removing unused variables * Aligning SharedKVCache with HybridCache (#11) * Alinging SharedKVStore with HybridCache * Remove KVStore. Refactor apply_rotary_pos_emb for sharing * Addressing review comments * Supporting split modality embeddings in Gemma3n (#10) * Adding the Embedder class * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Addressing review comments, adding audio embedding layers, integrating embedder with the remaining architecture, adding a forward method for conditional generation * Apply suggestions from code review Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Addressing review comments, prop drilling audio and vision configs to the text config * Removing TODO's that have been addressed * Simplify Embedder init and add audio embeddings * Embeddings refactor. Adds Gemma3nAudioEmbedder and Gemma3nVisionEmbedder * Refactoring vision and audio embeddings into ConditionalGeneration model --------- Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating attention mask for Gemma 3.5 (#15) * xxx_token_index to xxx_token_id * remvoing deprecated last_cache_position * Removing references to SigLIP * Always init per-layer inputs * Using torch.finfo().min for epsilon_tensor * Gemma3nDecoderLayer inherits from Gemma3DecoderLayer. 
Remove gating lambdas * fix modular GEMMA3N_INPUTS_DOCSTRING * Gemma3nAttention inherits from Gemma3Attention * Modular inheritance fixes * CausalLM conversion script for 4B model (#16) * Add Gemma3n Audio Encoder (#6) * initial commit of Gemma 3.5 scaffold * Fixing param pass through on Gemm3nRMSNorm * Adds Einsum layer to Gemma 3.5 * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma3n overall and text config with vision and audio config placeholders (#3) * Adding gemma3n text configs * Adding audio config placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3.5 (#3) * Initial Gemm3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3.5 * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right Gemma 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3.5 * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * Adding audio encoder config 
* Adds high-level components for Audio Encoder * Implement uniform reducer for Audio Encoder * Adding placeholders for Conformer components in Audio Encoder * Adding placeholders for SubSampleConvProjection components in Audio Encoder * Adding SequenceLayer component placeholders * Implementing Gemma3nAudioEncoder with nn.Sequential * Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential * Implementing Conformer model with SequenceLayers * Use OrderedDict in nn.Sequential initializers * Implements sl.Residual in Torch with nn.Sequential and OrderedDict * Adopting a base SequenceLayer class with default forward() method * Implementing sl.GatedLinearUnit in Torch * Implementing sl.Swish in Torch * Implementing sl.ReLU in Torch * Implementing sl.Scale in Torch * Removing sl.Dropout after tree-shaking * Implementing sl.RMSNorm in Torch with fake shape * Implementing sl.GroupNorm in Torch * Implementing sl.Conv2d in Torch * Implementing sl.Dense in Torch * Removing sl.Delay layers, which act as pass-throughs * Connecting shapes to configs in initializers * Removing sl.Emit * Implementing sl.ExpandDims in Torch * Adding sl.GradientClipping to Torch * Implementing sl.DenseShaped in Torch * Implementing sl.LDPA in Torch * Removing unused sl.CombinedQKVProj class * Fixing erroneous type hint * Implemnenting sl.DepthwiseConv1D in Torch * Implementing sl.MaskInvalid in Torch * Fixes for initialization * Fixes for saving weights * Removing einsums per feedback from HF staff * Removing Sequence Layers idioms from audio encoder * Fixes for reviewer comments * CausalLM conversion script for 4B model * inv_timescales to non-persistent buffer * Addressing audio encoder Attention feedback * Addressing Gemma3nAudioSSCPConvBlock feedback * Addressing Gemma3nAudioConformerAttention feedback * Addressing padding feedback * Weights conversion loads audio state dict * Always use vision_config so saving works * Token id updates for configs * Stubs for interleaving audio embs 
* Addressing reviewer feedback --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> * Fixing cache access error * Removing duplicate code from a bad merge * Gemma 3n Text + Vision Part 1 (#17) * testing utilities for numerics comparisons * Corrected einsum to nn.Linear weights conversion * Inherit scaled word embs from Gemma3 not Bart * Fixing transposes for collapsed linears * More transpose fixes * numpy api fix * RMSNorm: Explicit kwargs, scale_shift=0.0 when with_scale=True * Force AltUp to float32 * Updating debugging script for AudioEncoder debugging * Support divide_weight_by_sqrt_fan_in from JAX for per-layer inputs * Correcting attention einsum conversions * RMSNorm in type of x * Fixing douplicate laurel norm/gating * KV sharing using the right previous indices * Refactor kv shared index computation. Correct frac_shared_layers * Use num_shared_layers instead of inferring from a fraction * fixing a bug for logging * Fix shared data_ptrs in altup inits * rope: adjust proj -> norm -> rope to preserve computation (#20) * rope: adjust proj -> norm -> rope to preserve computation * Removing some breaking language model fluff in ConditionalGeneration * Consolidate query_states transforms --------- Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Vectorize the loops in AltUp (#19) * Vectorize the loops in AltUp * fix typo * Expanding to support batched inputs * remove extra debug script * Fix AltUp.forward --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Add 'scale_shift=0.0, with_scale=True' to the final norm in TextModel * Convert norm to 1/sqrt (#21) * Convert norm to 1/sqrt * Scale shift change per Phil's rec * Adding default activation sparsity * Fixing 2B config in weights conversion script * Fixing RMSNorm parameters - adding scale_shift and with_scale * 
Correcting query pre-attention scaling * Adding query_rescale_scalar to text config * Adding layer_idx to MLP * Permafix for input_layernorm * Use 1/sqrt instead of rsqrt in DecoderLayer * Fix o_proj conversion * Conversion script update for vision encoder * Removing logging for debugging timm model * Fixing bugs in Gemma3nForConditionalGeneration for text generation * Generating the modeling_gemma3n.py file * Removing the addition of an erroneous line in the modeling file * Adding gemma3n text model to modeling_auto * Bugfix: Updating the interleaving of inputs_embeds and vision_embeds * Updating the modeling file with the latest bugfix changes * Updating models/auto for Gemma 3n * using AutoTokenizer in forward test * Adding processing_gemma3n.py * Gemma 3n configured for AutoModel. Conversion script updated. * Removing errant merge artifacts --------- Co-authored-by: Mayank Chaturvedi <imayank@google.com> Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> * Removing errant debugging statements from Gemma 3 * Gemma3n audio model (#18) * testing utilities for numerics comparisons * Implement CumulativeGroupNorm and add to SubSampleConvProjection and SSCPConvBlock * Add audio version of forward script based on RyanMullins' implementation * Updating to match encoder tests. 
WIP: config question needs resolving * Updates to audio classes to enable end-to-end running * Removing vestigial classes, cleaning up print statements * Adding SiLU / Swish to audio conformer feed forward block * Shifted Gemma3p5Audio naming prefix to Gemma3NanoAudio * Adding outputs to audio test * Fixes to padding in SSCP and 1D convolution, align RMS Norm with wider model * Update forward test to load from local weights * Update conversion to process / output audio layers * Update __all__ to export audio encoder * AutoModel registration for Gemma 3n Audio * Use AutoModel for ConditionalGeneration.audio_tower * Fixing input_proj_linear transpose * Fixing Gemma3NanoAudioConformerAttention.post conversion * Fixing Gemma3NanoAudioSSCPConvBlock.conv weights conversion * Correcting indentation issue on Gemma3p5RMSNorm --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Text + Vision Part 2 (#23) * Updates for ConditionalGeneration.get_image_features * Adding a WIP draft of image_processing_gemma3p5.py * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Modular conversion after github suggested change * Text + image gives good results * Fixing image size preset * Updating configs for the 2B variant in the conversion script * Using final generation config in conversion script --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Audio Integration (#12) * initial commit of Gemma 3n scaffold * Fixing param pass through on Gemm3nRMSNorm * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma 3n overall and text config with vision and audio config placeholders (#3) * Adding Gemma 3n text configs * Adding audio config 
placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3n (#3) * Initial Gemma3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3n * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * Adding audio encoder config * Adds 
high-level components for Audio Encoder * Implement uniform reducer for Audio Encoder * Adding placeholders for Conformer components in Audio Encoder * Adding placeholders for SubSampleConvProjection components in Audio Encoder * Adding SequenceLayer component placeholders * Implementing Gemma3nAudioEncoder with nn.Sequential * Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential * Implementing Conformer model with SequenceLayers * Use OrderedDict in nn.Sequential initializers * Implements sl.Residual in Torch with nn.Sequential and OrderedDict * Adopting a base SequenceLayer class with default forward() method * Implementing sl.GatedLinearUnit in Torch * Implementing sl.Swish in Torch * Implementing sl.ReLU in Torch * Implementing sl.Scale in Torch * Removing sl.Dropout after tree-shaking * Implementing sl.RMSNorm in Torch with fake shape * Implementing sl.GroupNorm in Torch * Implementing sl.Conv2d in Torch * Implementing sl.Dense in Torch * Removing sl.Delay layers, which act as pass-throughs * Connecting shapes to configs in initializers * Removing sl.Emit * Implementing sl.ExpandDims in Torch * Adding sl.GradientClipping to Torch * Implementing sl.DenseShaped in Torch * Implementing sl.LDPA in Torch * Removing unused sl.CombinedQKVProj class * Fixing erroneous type hint * Implemnenting sl.DepthwiseConv1D in Torch * Implementing sl.MaskInvalid in Torch * Fixes for initialization * Fixes for saving weights * Removing einsums per feedback from HF staff * Removing Sequence Layers idioms from audio encoder * Fixes for reviewer comments * Converting sl.Frontend to FeatureExtractor * Updates for ConditionalGeneration.get_image_features * Adding a WIP draft of image_processing_gemma3n.py * Update modular Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Modular conversion after github suggested change * Text + image gives good results * Fixing image size preset * Draft of audio data in chat template * Removing 
image processing. Using SigLIP instead. * Audio input going end-to-end * Fixing dtype issues in audio encoder * x-lib formatting consistency * Adding example data * Save preprocessor_config.json from conversion script * Instrumentaiton for debugging * Additional instrumentation for preprocessing debugging * Updates to preprocessor, padding; produces correct end-to-end results on sample * Tackling configuraiton TODOs * Start of feature extractor refatcor * Adds Numpy version of USM extractor, removes Torch version and dependencies * Fixing AltUp.correct coef permute * Supporting batches of single audio segment inputs * Docstrings updates for config * In-lining audio feature extraction * Adjustments to conversion script and smoke test script --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: pculliton <phillipculliton@gmail.com> * Gemma 3n renaming * Removing test data and utilities * Renaming test files * Gemma 3n refactor * Fix tokenizer config in conversion script * Address reviewer feedback * FeatureExtractor returns float32 by default * Adding basic tests for audio, and input name for audio encoder * Audio integration test, updates to model_id for other integration tests * Use scales for q and k norms (#26) * Update audio integration test to use HF dataset * Reviewer feedback * Expand embedding table to full vocab size in weights conversion * Mix-n-match MatFormers for Gemma 3n (#25) * Remove in-place operations (#30) * chore: removing inplace ops * remove [tensor] * n pattern * chore: reviewer feedback in AudioEncoder and AltUp * More grad clipping * Dynamo compatibility * fix: cache slicing error * chore: simplify shared kv cache slicing * chore: vision encoder rename in timm * fix: image processor do_normalize=False * fixup: style * chore: model_doc * fix: docs for code quality * chore: repo consistency * fix: RMSNorm in float as in prior 
Gemmas * fix: per_layer_inputs = None * chore: Gemma3nForCausalLM from Gemma3nForConditionalGeneration checkpoint * chore: repo consistency * Add initial unit tests for Gemma3nAudioFeatureExtractor (#27) * Add initial unit tests for Gemma3nAudioFeatureExtractor * Add basic unit tests for Gemma3nProcessor (#28) Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> * parameterize tests --------- Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> * chore: code style * fix: test cases * style and consistency * fix config in the test to be coherent with layer cache sharing * fix hidden states in tests and code * inits and mappings * fix modality prefixes * test order and prefixes * fix test exception * fix class order and reduce model size for faster tests * restore _checkpoint_conversion_mapping to load Caual from Conditional * fix config mapping! * fix: reviewer feedback --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Mayank Chaturvedi <imayank@google.com> Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: pculliton <phillipculliton@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * fix import test * add model args * auto_docstring * replace test path * consistency * skip tests for now * fix docstring for doc builder * skip unused attr --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Mayank Chaturvedi <imayank@google.com> 
Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: pculliton <phillipculliton@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Arthur <arthur.zucker@gmail.com>
887 lines
36 KiB
Python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma3n model."""

import tempfile
import unittest

import numpy as np
import pytest
from datasets import load_dataset
from parameterized import parameterized

from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Gemma3nAudioConfig,
    Gemma3nAudioFeatureExtractor,
    Gemma3nConfig,
    Gemma3nTextConfig,
    GenerationConfig,
    is_torch_available,
)
from transformers.testing_utils import (
    cleanup,
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ..gemma.test_modeling_gemma import GemmaModelTester


if is_torch_available():
    import torch

    from transformers import (
        Gemma3nAudioEncoder,
        Gemma3nForCausalLM,
        Gemma3nForConditionalGeneration,
        Gemma3nModel,
        Gemma3nTextModel,
    )
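
# A minimal sketch (an illustration, not the transformers implementation) of the
# optional-dependency guard used above. The torch-only classes are imported inside
# `if is_torch_available():` so that this module can still be imported without torch.
# In that case, the torch-dependent tests are skipped by @require_torch. The helper
# name `_is_package_importable` is hypothetical. It only asks Python's import system
# whether a top-level package can be found, without importing that package.


def _is_package_importable(package_name: str) -> bool:
    # Hypothetical helper, not part of transformers.
    import importlib.util

    # find_spec returns None when no top-level module or package has this name.
    return importlib.util.find_spec(package_name) is not None


# Usage: _is_package_importable("json") is True, while an uninstalled name returns False.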


class Gemma3nAudioModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        num_channels=32,  # feature_size / input_feat_size
        sampling_rate=16_000,
        raw_audio_length=8_000,
        is_training=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.sampling_rate = sampling_rate
        self.raw_audio_length = raw_audio_length
        self.is_training = is_training

    def get_feature_extractor_config(self):
        return {
            "feature_size": self.num_channels,
            "sampling_rate": self.sampling_rate,
            "padding_value": 0.0,
            "return_attention_mask": True,
            "frame_length_ms": 32.0,
            "hop_length_ms": 10.0,
            "dither": 0.0,  # Important for determinism
        }

    def get_audio_encoder_config(self):
        return Gemma3nAudioConfig(
            input_feat_size=self.num_channels,
            hidden_size=32,
            conf_num_attention_heads=4,
            conf_num_hidden_layers=2,
            sscp_conv_channel_size=(16, 8),
            conf_conv_kernel_size=3,
            conf_attention_chunk_size=4,
            conf_attention_context_left=5,
        )

    def prepare_config_and_inputs_for_common(self):
        # Prepare inputs for the audio encoder
        feature_extractor_config = self.get_feature_extractor_config()
        audio_encoder_config = self.get_audio_encoder_config()

        np.random.seed(0)
        raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.raw_audio_length)).astype(np.float32)
        raw_speech_2 = np.random.randn(self.raw_audio_length // 2).astype(np.float32)
        raw_speech = [raw_speech_1, raw_speech_2]

        feature_extractor = Gemma3nAudioFeatureExtractor(**feature_extractor_config)
        audio_inputs = feature_extractor(raw_speech, return_tensors="pt")

        input_features = audio_inputs["input_features"]
        # The encoder expects a padding mask (True for padding), while the feature extractor
        # returns an attention mask (True for valid frames), so it must be inverted. Casting to
        # bool first makes `~` a logical NOT rather than a bitwise NOT on integer values.
        input_features_mask = ~audio_inputs["input_features_mask"].to(torch.bool)

        inputs_dict = {
            "audio_mel": input_features,
            "audio_mel_mask": input_features_mask,
        }
        return audio_encoder_config, inputs_dict
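
# The tester above passes two clips of different lengths (8000 and 4000 samples) through
# the feature extractor. It then inverts the returned attention mask into the padding
# mask that the encoder expects. What follows is a minimal NumPy sketch of that
# bookkeeping under stated assumptions. It uses plain "valid" framing, with no padding
# of the waveform and no extra samples per frame. Its frame and hop sizes are derived
# from frame_length_ms and hop_length_ms.
# Gemma3nAudioFeatureExtractor may frame or pad differently, so these counts are not
# expected to reproduce the golden (2, 48) shapes recorded in the test class below.
# `_ms_to_samples` and `_frame_masks_for_clips` are hypothetical helpers, not transformers APIs.
import numpy as np


def _ms_to_samples(duration_ms: float, sampling_rate: int) -> int:
    # Hypothetical helper. At 16 kHz, 32 ms -> 512 samples and 10 ms -> 160 samples.
    return int(round(duration_ms * sampling_rate / 1000.0))


def _frame_masks_for_clips(clip_lengths, sampling_rate=16_000, frame_length_ms=32.0, hop_length_ms=10.0):
    # Hypothetical helper: frame counts plus attention/padding masks for a padded batch.
    frame_length = _ms_to_samples(frame_length_ms, sampling_rate)
    hop_length = _ms_to_samples(hop_length_ms, sampling_rate)
    # Number of full frames that fit in each clip (0 if the clip is shorter than one frame).
    num_frames = [max(0, (n - frame_length) // hop_length + 1) for n in clip_lengths]
    max_frames = max(num_frames)
    # Attention mask: True marks a real frame. Shorter clips are padded up to max_frames.
    attention_mask = np.arange(max_frames)[None, :] < np.array(num_frames)[:, None]
    # Padding mask: the logical inverse. `~` is a logical NOT only on a bool array. On an
    # int array, it is a bitwise NOT (~1 == -2), hence the explicit (here redundant) cast.
    padding_mask = ~attention_mask.astype(bool)
    return num_frames, attention_mask, padding_mask


# Usage: _frame_masks_for_clips([8000, 4000]) gives frame counts [47, 22]. The second row
# of the padding mask is False for its first 22 frames and True for the remaining 25.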


@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nAudioModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Gemma3nAudioEncoder,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    test_missing_keys = False
    is_generative = False
    _is_stateful = True
    main_input_name = "audio_mel"
    test_initialization = False
    test_can_init_all_missing_weights = False

    def setUp(self):
        self.model_tester = Gemma3nAudioModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Gemma3nAudioConfig, hidden_size=37)
        torch.manual_seed(0)

        # The following values are golden outputs from a deterministic run of the components.
        # They are used to ensure that changes to the code do not alter the numerical output.
        # Generated with seeds np.random.seed(0) and torch.manual_seed(0).
        self.expected_input_features_shape = (2, 48, 32)
        self.expected_input_features_slice = np.array([-5.733152, -5.337127, -4.916284, -4.378989, -3.7622747])
        self.expected_input_features_mask_shape = (2, 48)
        self.expected_input_features_mask_slice = np.array([True, True, True, True, False])

        self.expected_encoder_output_shape = (2, 3, 32)
        self.expected_encoder_output_slice = torch.tensor([-0.4159, 0.6459, 0.6305, 2.2902, 0.9683])
        self.expected_encoder_mask_shape = (2, 3)
        self.expected_encoder_mask_slice = torch.tensor([False, False, True])

        # Prepare a shared feature extractor and raw audio for the tests
        self.feature_extractor = Gemma3nAudioFeatureExtractor(**self.model_tester.get_feature_extractor_config())
        np.random.seed(0)
        raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.model_tester.raw_audio_length)).astype(
            np.float32
        )
        raw_speech_2 = np.random.randn(self.model_tester.raw_audio_length // 2).astype(np.float32)
        self.raw_speech = [raw_speech_1, raw_speech_2]

    @unittest.skip("Audio encoder does not support attention output")
    def test_attention_outputs(self):
        pass

    @unittest.skip("Audio encoder does not support hidden state output")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("Audio encoder returns a tuple, not a ModelOutput object, skipping equivalence test.")
    def test_model_outputs_equivalence(self):
        pass

    @unittest.skip("Audio encoder does not support retaining gradients on hidden states/attentions.")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip("Audio encoder does not have a concept of token embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip("Audio encoder does not have a concept of token embeddings")
    def test_resize_tokens_embeddings(self):
        pass

    @unittest.skip("This model has a complex downsampling scheme that is hard to test with the generic batching test.")
|
|
def test_batching_equivalence(self):
|
|
pass
|
|
|
|
def test_feature_extractor(self):
|
|
"""
|
|
Tests the feature extractor's output against pre-computed golden values.
|
|
This ensures the NumPy-based audio preprocessing is correct and consistent.
|
|
"""
|
|
audio_inputs = self.feature_extractor(
|
|
self.raw_speech, padding="longest", pad_to_multiple_of=128, return_tensors="np"
|
|
)
|
|
|
|
input_features = audio_inputs["input_features"]
|
|
self.assertEqual(input_features.shape, self.expected_input_features_shape)
|
|
np.testing.assert_allclose(input_features[0, 0, :5], self.expected_input_features_slice, rtol=1e-5, atol=1e-5)
|
|
|
|
|
|
|
|
input_features_mask = audio_inputs["input_features_mask"]
|
|
self.assertEqual(input_features_mask.shape, self.expected_input_features_mask_shape)
|
|
# The second audio sample is shorter (25 valid frames vs 48), so its mask becomes False at index 25
|
|
np.testing.assert_array_equal(input_features_mask[1, 21:26], self.expected_input_features_mask_slice)
|
|
|
|
def test_audio_encoder(self):
|
|
"""
|
|
Tests the audio encoder's forward pass against pre-computed golden values.
|
|
This ensures the PyTorch-based audio encoding model is correct and consistent.
|
|
"""
|
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
model = Gemma3nAudioEncoder(config).to(torch_device).eval()
|
|
|
|
with torch.no_grad():
|
|
encoder_output, encoder_mask = model(**inputs_dict)
|
|
|
|
|
|
|
|
# Check output encodings
|
|
self.assertEqual(encoder_output.shape, self.expected_encoder_output_shape)
|
|
torch.testing.assert_close(
|
|
encoder_output[0, 0, :5], self.expected_encoder_output_slice.to(torch_device), rtol=1e-4, atol=1e-4
|
|
)
|
|
|
|
# Check output mask (True means padded)
|
|
# Downsampling by 4 (conv) and then by 4 (reduction) leaves 3 output frames. The second sample has only
# 25 valid feature frames, which cover just the first two output frames.
# So the mask should be [False, False, True]
|
|
self.assertEqual(encoder_mask.shape, self.expected_encoder_mask_shape)
|
|
torch.testing.assert_close(encoder_mask[1, :], self.expected_encoder_mask_slice.to(torch_device))
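# A hypothetical sketch (not the Gemma3nAudioEncoder implementation) of why the golden encoder mask for the
# second sample is [False, False, True]. It assumes each 4x downsampling stage keeps every 4th frame of the
# padding mask (True = padded); the real encoder may derive its output mask differently.
def _downsample_padding_mask(padding_mask, strides=(4, 4)):
    for stride in strides:
        # Keep frames 0, stride, 2 * stride, ... of the current mask.
        padding_mask = padding_mask[::stride]
    return padding_mask


# Usage: 48 feature frames, of which the first 25 are valid (as in the feature-extractor golden mask).
# _downsample_padding_mask([False] * 25 + [True] * 23) -> [False, False, True]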
|
|
|
|
|
|
class Gemma3nTextModelTester(GemmaModelTester):
|
|
activation_sparsity_pattern = None
|
|
forced_config_args = ["activation_sparsity_pattern"]
|
|
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
batch_size=13,
|
|
seq_length=7,
|
|
is_training=True,
|
|
use_input_mask=True,
|
|
use_token_type_ids=False,
|
|
use_labels=True,
|
|
vocab_size=99,
|
|
vocab_size_per_layer_input=99,
|
|
hidden_size=16,
|
|
num_hidden_layers=4, # override to correctly test sharing cache pattern
|
|
num_kv_shared_layers=2, # important to override
|
|
layer_types=[
|
|
"full_attention",
|
|
"sliding_attention",
|
|
"full_attention",
|
|
"sliding_attention",
|
|
], # similarly we want to test sharing on both types
|
|
num_attention_heads=2,
|
|
num_key_value_heads=2,
|
|
altup_num_inputs=2,
|
|
intermediate_size=21,
|
|
hidden_activation="gelu_pytorch_tanh",
|
|
max_position_embeddings=512,
|
|
type_vocab_size=16,
|
|
type_sequence_label_size=2,
|
|
initializer_range=0.02,
|
|
num_labels=3,
|
|
num_choices=4,
|
|
pad_token_id=0,
|
|
bos_token_id=1,
|
|
eos_token_id=2,
|
|
is_decoder=False,
|
|
):
|
|
self._verify_model_attributes()
|
|
self.parent = parent
|
|
self.batch_size = batch_size
|
|
self.seq_length = seq_length
|
|
self.is_training = is_training
|
|
self.use_input_mask = use_input_mask
|
|
self.use_token_type_ids = use_token_type_ids
|
|
self.use_labels = use_labels
|
|
self.vocab_size = vocab_size
|
|
self.vocab_size_per_layer_input = vocab_size_per_layer_input
|
|
self.hidden_size = hidden_size
|
|
self.num_hidden_layers = num_hidden_layers
|
|
self.num_kv_shared_layers = num_kv_shared_layers
|
|
self.layer_types = layer_types
|
|
self.num_attention_heads = num_attention_heads
|
|
self.num_key_value_heads = num_key_value_heads
|
|
self.altup_num_inputs = altup_num_inputs
|
|
self.intermediate_size = intermediate_size
|
|
self.hidden_activation = hidden_activation
|
|
self.max_position_embeddings = max_position_embeddings
|
|
self.type_vocab_size = type_vocab_size
|
|
self.type_sequence_label_size = type_sequence_label_size
|
|
self.initializer_range = initializer_range
|
|
self.num_labels = num_labels
|
|
self.num_choices = num_choices
|
|
self.pad_token_id = pad_token_id
|
|
self.bos_token_id = bos_token_id
|
|
self.eos_token_id = eos_token_id
|
|
self.head_dim = self.hidden_size // self.num_attention_heads
|
|
self.is_decoder = is_decoder
|
|
|
|
if is_torch_available():
|
|
config_class = Gemma3nTextConfig
|
|
model_class = Gemma3nTextModel
|
|
for_causal_lm_class = Gemma3nForCausalLM
|
|
|
|
|
|
@unittest.skip("Skipped for now!")
|
|
@require_torch
|
|
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
|
|
all_generative_model_classes = (Gemma3nForCausalLM,) if is_torch_available() else ()
|
|
test_headmasking = False
|
|
test_pruning = False
|
|
_is_stateful = True
|
|
model_split_percents = [0.5, 0.6]
|
|
|
|
def setUp(self):
|
|
self.model_tester = Gemma3nTextModelTester(self)
|
|
self.config_tester = ConfigTester(
|
|
self,
|
|
config_class=Gemma3nConfig,
|
|
hidden_size=37,
|
|
text_config={"activation_sparsity_pattern": None},
|
|
)
|
|
|
|
def _check_hidden_states_for_generate(
|
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
|
):
|
|
"Gemma3n has special hidden states shape with 1 additional dim (which is then reduced with projections)"
|
|
|
|
self.assertIsInstance(hidden_states, tuple)
|
|
self.assertListEqual(
|
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
[True] * len(hidden_states),
|
|
)
|
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
|
|
|
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
|
|
# new token(s)
|
|
# NOTE: `HybridCache` may have different lengths on different layers; if this test starts failing, add more
# elaborate checks
|
|
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
|
if use_cache and generated_length > 0:
|
|
model_input_length = 1
|
|
else:
|
|
model_input_length = prompt_length + generated_length
|
|
expected_shape = (config.altup_num_inputs, batch_size, model_input_length, config.hidden_size)
|
|
# check hidden size
|
|
self.assertListEqual(
|
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
[expected_shape] * len(iter_hidden_states),
|
|
)
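# A minimal sketch of the shape bookkeeping in `_check_hidden_states_for_generate`, written as a hypothetical
# standalone helper. Every generation step yields hidden states of shape
# (altup_num_inputs, batch_size, model_input_length, hidden_size), where model_input_length is the full
# prompt on the first step and, when a cache is used, a single token on every later step.
def _expected_generate_hidden_state_shapes(
    batch_size, prompt_length, output_length, altup_num_inputs, hidden_size, use_cache=False
):
    shapes = []
    for generated_length in range(output_length - prompt_length):
        if use_cache and generated_length > 0:
            model_input_length = 1
        else:
            model_input_length = prompt_length + generated_length
        shapes.append((altup_num_inputs, batch_size, model_input_length, hidden_size))
    return shapes


# Usage: _expected_generate_hidden_state_shapes(2, 5, 8, 2, 16, use_cache=True)
# -> [(2, 2, 5, 16), (2, 2, 1, 16), (2, 2, 1, 16)]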
|
|
|
|
|
|
class Gemma3nVision2TextModelTester:
|
|
text_config = {"activation_sparsity_pattern": None}
|
|
forced_config_args = ["text_config"]
|
|
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
mm_tokens_per_image=2,
|
|
image_token_index=1,
|
|
boi_token_index=2,
|
|
eoi_token_index=3,
|
|
seq_length=25,
|
|
is_training=True,
|
|
vision_config={
|
|
"use_labels": True,
|
|
"image_size": 20,
|
|
"patch_size": 5,
|
|
"num_channels": 3,
|
|
"is_training": True,
|
|
"hidden_size": 32,
|
|
"num_key_value_heads": 1,
|
|
"num_hidden_layers": 2,
|
|
"num_attention_heads": 4,
|
|
"intermediate_size": 37,
|
|
"dropout": 0.1,
|
|
"attention_dropout": 0.1,
|
|
"initializer_range": 0.02,
|
|
},
|
|
use_cache=False,
|
|
):
|
|
self.parent = parent
|
|
# `image_token_index` is set to 1 to pass the "resize_embeddings" test, do not modify
|
|
self.mm_tokens_per_image = mm_tokens_per_image
|
|
self.image_token_index = image_token_index
|
|
self.boi_token_index = boi_token_index
|
|
self.eoi_token_index = eoi_token_index
|
|
self.llm_tester = Gemma3nTextModelTester(self.parent)
|
|
self.text_config = self.llm_tester.get_config()
|
|
self.vision_config = vision_config
|
|
self.seq_length = seq_length
|
|
self.pad_token_id = self.text_config.pad_token_id
|
|
|
|
self.num_hidden_layers = self.text_config.num_hidden_layers
|
|
self.vocab_size = self.text_config.vocab_size
|
|
self.hidden_size = self.text_config.hidden_size
|
|
self.num_attention_heads = self.text_config.num_attention_heads
|
|
self.is_training = is_training
|
|
|
|
self.batch_size = 3
|
|
self.num_channels = vision_config["num_channels"]
|
|
self.image_size = vision_config["image_size"]
|
|
self.encoder_seq_length = seq_length
|
|
self.use_cache = use_cache
|
|
|
|
def get_config(self):
|
|
return Gemma3nConfig(
|
|
text_config=self.text_config,
|
|
vision_config=self.vision_config,
|
|
image_token_index=self.image_token_index,
|
|
boi_token_index=self.boi_token_index,
|
|
eoi_token_index=self.eoi_token_index,
|
|
mm_tokens_per_image=self.mm_tokens_per_image,
|
|
)
|
|
|
|
def prepare_config_and_inputs(self):
|
|
pixel_values = floats_tensor(
|
|
[
|
|
self.batch_size,
|
|
self.vision_config["num_channels"],
|
|
self.vision_config["image_size"],
|
|
self.vision_config["image_size"],
|
|
]
|
|
)
|
|
config = self.get_config()
|
|
|
|
return config, pixel_values
|
|
|
|
def prepare_config_and_inputs_for_common(self):
|
|
config_and_inputs = self.prepare_config_and_inputs()
|
|
config, pixel_values = config_and_inputs
|
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
|
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
|
|
|
# set the first token to be an image token, and ensure that no other tokens are image tokens
|
|
# do not change this unless you modified image size or patch size
|
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
|
input_ids[:, :1] = config.image_token_index
|
|
|
|
token_type_ids = torch.zeros_like(input_ids)
|
|
token_type_ids[input_ids == config.image_token_index] = 1
|
|
|
|
inputs_dict = {
|
|
"pixel_values": pixel_values,
|
|
"input_ids": input_ids,
|
|
"attention_mask": attention_mask,
|
|
"token_type_ids": token_type_ids,
|
|
}
|
|
return config, inputs_dict
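# A hypothetical pure-Python sketch of how `prepare_config_and_inputs_for_common` above builds its image-token
# inputs: randomly drawn image tokens are replaced by padding, the first position of each row becomes the
# image token, and `token_type_ids` marks image positions with 1. Lists stand in for torch tensors.
def _mark_image_tokens(input_ids, image_token_index, pad_token_id):
    # Build new rows so the caller's lists are not mutated.
    input_ids = [[pad_token_id if tok == image_token_index else tok for tok in row] for row in input_ids]
    for row in input_ids:
        row[0] = image_token_index
    token_type_ids = [[1 if tok == image_token_index else 0 for tok in row] for row in input_ids]
    return input_ids, token_type_ids


# Usage with image_token_index=1 and pad_token_id=0:
# _mark_image_tokens([[5, 1, 7], [1, 3, 1]], 1, 0) -> ([[1, 0, 7], [1, 3, 0]], [[1, 0, 0], [1, 0, 0]])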
|
|
|
|
|
|
@unittest.skip("Skipped for now!")
|
|
@require_torch
|
|
class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|
all_model_classes = (Gemma3nModel, Gemma3nForConditionalGeneration) if is_torch_available() else ()
|
|
all_generative_model_classes = (Gemma3nForConditionalGeneration,) if is_torch_available() else ()
|
|
test_headmasking = False
|
|
test_pruning = False
|
|
test_missing_keys = False
|
|
_is_stateful = True
|
|
model_split_percents = [0.5, 0.6]
|
|
|
|
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
|
|
# TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
|
|
# in the dispatch_model function
|
|
test_cpu_offload = False
|
|
test_disk_offload_safetensors = False
|
|
test_disk_offload_bin = False
|
|
|
|
def setUp(self):
|
|
self.model_tester = Gemma3nVision2TextModelTester(self)
|
|
self.config_tester = ConfigTester(
|
|
self,
|
|
config_class=Gemma3nConfig,
|
|
hidden_size=37,
|
|
text_config={"activation_sparsity_pattern": None},
|
|
)
|
|
|
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
|
def test_training_gradient_checkpointing(self):
|
|
pass
|
|
|
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
|
pass
|
|
|
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in DynamicCache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi-GPU settings"
|
|
)
|
|
def test_multi_gpu_data_parallel_forward(self):
|
|
pass
|
|
|
|
@unittest.skip("Failing because of unique cache (HybridCache)")
|
|
def test_model_outputs_equivalence(self, **kwargs):
|
|
pass
|
|
|
|
@parameterized.expand([("random",), ("same",)])
|
|
@pytest.mark.generate
|
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
|
def test_assisted_decoding_sample(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding")
|
|
def test_dola_decoding_sample(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv")
|
|
def test_generate_continue_from_past_key_values(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation")
|
|
def test_beam_search_low_memory(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
|
def test_contrastive_generate(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
|
def test_contrastive_generate_low_memory(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n uses HybridCache and doesn't support StaticCache. It could in principle, but it shouldn't.")
|
|
def test_generate_with_static_cache(self):
|
|
pass
|
|
|
|
@unittest.skip("Gemma3n uses HybridCache and doesn't support StaticCache. It could in principle, but it shouldn't.")
|
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
|
|
)
|
|
def test_initialization(self):
|
|
pass
|
|
|
|
@unittest.skip(
|
|
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
|
)
|
|
def test_flex_attention_with_grads(self):
|
|
pass
|
|
|
|
def test_automodelforcausallm(self):
|
|
"""
|
|
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3n config, i.e. that
|
|
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
|
"""
|
|
config = self.model_tester.get_config()
|
|
model = Gemma3nForConditionalGeneration(config)
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
model.save_pretrained(tmp_dir)
|
|
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
|
self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)
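# Several integration tests below request left padding (`padding_side="left"`). A hypothetical sketch of left
# padding for batched generation: shorter prompts are padded at the front, so every row ends at the same
# position and generated tokens follow each prompt directly. The attention mask is 0 on padding positions.
def _left_pad(sequences, pad_token_id):
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        num_pad = max_len - len(seq)
        input_ids.append([pad_token_id] * num_pad + list(seq))
        attention_mask.append([0] * num_pad + [1] * len(seq))
    return input_ids, attention_mask


# Usage: _left_pad([[5, 6, 7], [8]], pad_token_id=0) -> ([[5, 6, 7], [0, 0, 8]], [[1, 1, 1], [0, 0, 1]])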
|
|
|
|
|
|
@unittest.skip("Skipped for now!")
|
|
@slow
|
|
@require_torch_gpu
|
|
@require_read_token
|
|
class Gemma3nIntegrationTest(unittest.TestCase):
|
|
def setUp(self):
|
|
self.processor = AutoProcessor.from_pretrained("Google/gemma-3n-E4B-it", padding_side="left")
|
|
|
|
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
|
self.messages = [
|
|
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image", "url": url},
|
|
{"type": "text", "text": "What is shown in this image?"},
|
|
],
|
|
},
|
|
]
|
|
|
|
audio_ds = load_dataset(
|
|
"etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav"
|
|
)
|
|
self.audio_file_path = audio_ds["train"][0]["audio"]["path"]
|
|
|
|
def tearDown(self):
|
|
cleanup(torch_device, gc_collect=True)
|
|
|
|
def test_model_4b_bf16(self):
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
|
).to(torch_device)
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
self.messages,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
add_generation_prompt=True,
|
|
).to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
def test_model_with_audio(self):
|
|
"""
|
|
Tests the full model pipeline with batched audio inputs provided as file paths.
|
|
This ensures the processor correctly loads and processes audio files.
|
|
"""
|
|
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
|
).to(torch_device)
|
|
|
|
messages = [
|
|
[
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "text", "text": "Transcribe the following speech segment in English:"},
|
|
{"type": "audio", "audio": str(self.audio_file_path)},
|
|
],
|
|
}
|
|
],
|
|
]
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
messages,
|
|
add_generation_prompt=True,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
padding=True,
|
|
return_tensors="pt",
|
|
).to(torch_device, dtype=model.dtype)
|
|
|
|
input_len = inputs["input_ids"].shape[-1]
|
|
|
|
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
|
|
output = output[:, input_len:]
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = ["Chapter 1. Mrs. Rachel Lind is surprised.\n\nMrs. Rachel Lind"]
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
def test_model_4b_batch(self):
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
|
|
).to(torch_device)
|
|
|
|
messages_2 = [
|
|
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "image",
|
|
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
|
|
},
|
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
|
{"type": "text", "text": "Are these images identical?"},
|
|
],
|
|
},
|
|
]
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
[self.messages, messages_2],
|
|
tokenize=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
padding=True,
|
|
add_generation_prompt=True,
|
|
).to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = [
|
|
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
|
|
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
|
|
] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
def test_model_4b_crops(self):
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
|
).to(torch_device)
|
|
|
|
crop_config = {
|
|
"images_kwargs": {
|
|
"do_pan_and_scan": True,
|
|
"pan_and_scan_max_num_crops": 448,
|
|
"pan_and_scan_min_crop_size": 32,
|
|
"pan_and_scan_min_ratio_to_activate": 0.3,
|
|
}
|
|
}
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
self.messages,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
add_generation_prompt=True,
|
|
**crop_config,
|
|
).to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_NUM_IMAGES = 3  # one for the original image and two crops of images
|
|
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
|
|
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
def test_model_4b_multiimage(self):
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
|
).to(torch_device)
|
|
|
|
messages = [
|
|
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
|
{"type": "text", "text": "What do you see here?"},
|
|
],
|
|
},
|
|
]
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
messages,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
padding=True,
|
|
add_generation_prompt=True,
|
|
).to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
def test_model_1b_text_only(self):
|
|
model_id = "google/gemma-3-1b-it"
|
|
|
|
model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
|
torch_device
|
|
)
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
|
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
# TODO: raushan FA2 generates gibberish for no reason, check later
|
|
@require_flash_attn
|
|
@require_torch_gpu
|
|
@pytest.mark.flash_attn_test
|
|
def test_model_4b_flash_attn(self):
|
|
model_id = "Google/gemma-3n-E4B-it"
|
|
|
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
|
model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
|
).to(torch_device)
|
|
|
|
inputs = self.processor.apply_chat_template(
|
|
self.messages,
|
|
tokenize=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
add_generation_prompt=True,
|
|
).to(torch_device)
|
|
|
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
|
|
|
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
|
|
|
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
|
|
def test_generation_beyond_sliding_window(self, attn_implementation: str):
|
|
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
|
|
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
|
|
Outputs for every attention function should be coherent and identical.
|
|
"""
|
|
model_id = "google/gemma-3-1b-it"
|
|
|
|
input_text = [
|
|
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
|
"A list of colors: red, blue", # This will almost all be padding tokens
|
|
]
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
|
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
|
).to(torch_device)
|
|
|
|
# Make sure prefill is larger than sliding window
|
|
input_size = inputs.input_ids.shape[-1]
|
|
self.assertTrue(input_size > model.config.sliding_window)
|
|
|
|
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
|
output_text = tokenizer.batch_decode(out)
|
|
|
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
|
|
|
def test_generation_beyond_sliding_window_with_generation_config(self):
|
|
"""
|
|
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
|
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
|
"""
|
|
model_id = "google/gemma-3-1b-it"
|
|
attn_implementation = "sdpa"
|
|
|
|
input_text = [
|
|
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
|
"A list of colors: red, blue", # This will almost all be padding tokens
|
|
]
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
|
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
|
).to(torch_device)
|
|
|
|
# Make sure prefill is larger than sliding window
|
|
input_size = inputs.input_ids.shape[-1]
|
|
self.assertTrue(input_size > model.config.sliding_window)
|
|
|
|
generation_config = GenerationConfig(max_new_tokens=20)
|
|
|
|
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
|
output_text = tokenizer.batch_decode(out)
|
|
|
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
|
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
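# A hypothetical sketch of the sliding-window causal mask that the sliding-window generation tests above rely
# on: query position i may attend to key position j only if j <= i (causal) and i - j < window (inside the
# window). It assumes the window includes the current token; the exact boundary convention used by
# transformers' mask utilities may differ.
def _sliding_window_causal_mask(seq_len, window):
    return [[j <= i and i - j < window for j in range(seq_len)] for i in range(seq_len)]


# Usage: _sliding_window_causal_mask(4, 2) ->
# [[True, False, False, False],
#  [True, True, False, False],
#  [False, True, True, False],
#  [False, False, True, True]]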
|