Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix Gemma2 synced multi-GPU generation (#35232)
* Fix Gemma2 synced multi-GPU generation
* Fix import ordering in modular_gemma2.py
This commit is contained in:
parent fa56dcc2ab
commit 4831a94ee7
modeling_gemma2.py
@@ -41,6 +41,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -936,8 +937,13 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
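To make the new Exception 3 branch concrete, here is a minimal, self-contained sketch (not part of the commit; tensor values are illustrative) of a synced-GPU dummy step, where cache_position has already advanced past the last column of input_ids:

import torch

input_ids = torch.tensor([[101, 7, 42, 9, 102]])  # 5 tokens already processed (illustrative values)
cache_position = torch.tensor([7])                # dummy step: 7 >= input_ids.shape[1]

if cache_position[-1] >= input_ids.shape[1]:      # Exception 3 branch of the fix
    # Reuse the trailing token as a dummy input instead of indexing out of bounds.
    next_input_ids = input_ids[:, -cache_position.shape[0] :]
else:
    # Default case: keep only the still-unprocessed tokens.
    next_input_ids = input_ids[:, cache_position]

print(next_input_ids)  # tensor([[102]]); the old path's input_ids[:, cache_position] would raise IndexError here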
modular_gemma2.py
@@ -29,7 +29,7 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import is_torchdynamo_compiling, logging
 from ..gemma.modeling_gemma import (
     GemmaAttention,
     GemmaForCausalLM,
@@ -668,8 +668,13 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
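For context, a hedged usage sketch (not from the repository; the model id, prompt, and launch details are placeholders) of the workload this commit targets. It is meant to run one process per GPU under a distributed launcher such as torchrun or accelerate launch, since synced_gpus=True relies on torch.distributed; a rank that finishes early then keeps issuing dummy forward passes, the path that previously pushed cache_position out of bounds:

import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder setup: the launcher (torchrun / accelerate launch) provides the rank env vars.
if not dist.is_initialized():
    dist.init_process_group("nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", torch_dtype=torch.bfloat16).to(local_rank)

inputs = tokenizer("Why is the sky blue?", return_tensors="pt").to(model.device)
# A rank that finishes early keeps running dummy forward passes so collectives stay
# in sync; the new Exception 3 branch keeps those dummy steps in bounds.
outputs = model.generate(**inputs, max_new_tokens=16, synced_gpus=True)
print(f"rank {dist.get_rank()}: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")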