Fix Gemma2 synced multi-GPU generation (#35232)

* Fix Gemma2 synced multi-GPU generation

* Fix import ordering in modular_gemma2.py
ManukyanD, 2025-02-05 13:07:50 +04:00 (committed by GitHub)
parent fa56dcc2ab
commit 4831a94ee7
2 changed files with 14 additions and 3 deletions

src/transformers/models/gemma2/modeling_gemma2.py

@@ -41,6 +41,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -936,8 +937,13 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
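
For context, a minimal sketch of the branch added above, using toy tensors and a made-up helper name (no model involved): on a synced-GPUs dummy step, cache_position can point past the end of input_ids, so the new condition falls back to the last cache_position.shape[0] tokens instead of indexing out of bounds.

```python
import torch

def slice_input_ids(input_ids, cache_position, inputs_embeds=None, compiling=False):
    # Mirrors the condition above: Exception 1 (inputs_embeds passed) and Exception 3
    # (compiling, or cache_position past the end of input_ids) take the tail slice.
    if inputs_embeds is not None or (compiling or cache_position[-1] >= input_ids.shape[1]):
        return input_ids[:, -cache_position.shape[0]:]
    elif input_ids.shape[1] != cache_position.shape[0]:  # default decoding step
        return input_ids[:, cache_position]
    return input_ids  # Exception 2: no-op

input_ids = torch.arange(6).reshape(1, 6)  # 6 tokens in the sequence so far

# Normal decoding step: cache_position points at a valid position, so it is gathered.
print(slice_input_ids(input_ids, torch.tensor([5])))  # tensor([[5]])

# Dummy step on a finished rank: cache_position has advanced past the sequence.
# The old code ran input_ids[:, cache_position] here and hit an IndexError;
# the fixed condition returns a dummy (last) token instead.
print(slice_input_ids(input_ids, torch.tensor([7])))  # tensor([[5]])
```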

src/transformers/models/gemma2/modular_gemma2.py

@@ -29,7 +29,7 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import is_torchdynamo_compiling, logging
 from ..gemma.modeling_gemma import (
     GemmaAttention,
     GemmaForCausalLM,
@@ -668,8 +668,13 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
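
A usage note, not part of the commit: the situation this fix targets looks roughly like the hedged sketch below. A rank that finishes generating early under synced_gpus=True keeps running dummy forward passes while cache_position continues to advance, which is what previously pushed the index out of bounds in prepare_inputs_for_generation. The model id, prompt, token budgets, and launch command are illustrative assumptions.

```python
# Launch with something like: torchrun --nproc_per_node=2 synced_gpus_demo.py
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model_id = "google/gemma-2-2b"  # illustrative; any Gemma2 checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(rank)

prompt = "Write a short story about a ship lost at sea:"
inputs = tokenizer(prompt, return_tensors="pt").to(rank)

# Rank 0 stops after fewer new tokens, so with synced_gpus=True it keeps doing
# dummy forward passes while rank 1 is still generating (the case fixed above).
out = model.generate(**inputs, max_new_tokens=32 if rank == 0 else 64, synced_gpus=True)
print(rank, tokenizer.decode(out[0], skip_special_tokens=True))

dist.destroy_process_group()
```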