FIX [Gemma] Fix bad rebase with transformers main (#29170)
fix bad rebase
This commit is contained in:
parent
594c1277b2
commit
ae49b218c3
@@ -124,7 +124,7 @@ def rotate_half(x):
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
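For context on why `position_ids` could be deprecated here: the rotary embedding only needs `cos` and `sin` tensors that have already been gathered for the current positions. Below is a minimal sketch of that computation, based on the standard RoPE formulation; `apply_rotary_pos_emb_sketch` is an illustrative stand-in, not the library function.

import torch

def rotate_half(x):
    # Rotate the last dimension by swapping and negating its two halves.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin are already indexed per position; unsqueeze so they broadcast over the head dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed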
@@ -940,6 +939,10 @@ class GemmaModel(GemmaPreTrainedModel):
             attentions=all_self_attns,
         )
 
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
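As a side note on the TODO above, here is a minimal, hypothetical sketch of the trade-off it describes (a toy helper, not the Gemma code, assuming torch>=2.2): `torch.compiler.disable` keeps dynamo from tracing a function, which sidesteps the cudagraph re-captures triggered by a dynamically shaped mask, but the resulting graph break means a caller compiled with `fullgraph=True` would fail.

import torch

@torch.compiler.disable  # runs eagerly: no re-tracing on dynamic shapes, but introduces a graph break
def to_additive_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # 2D padding mask of dynamic length -> additive float mask broadcastable over heads and query positions.
    return (1.0 - attention_mask.to(dtype))[:, None, None, :] * torch.finfo(dtype).min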
@@ -955,16 +958,8 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
 
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
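A minimal standalone sketch (with made-up sizes, not the GemmaModel code) of the pattern the simplified branch relies on: a registered upper-triangular `causal_mask` buffer is expanded to the batch and turned into an additive bias in the working dtype via `torch.finfo(dtype).min`.

import torch

max_len, batch_size, dtype = 8, 2, torch.float16  # illustrative values

# Buffer of ones strictly above the diagonal (1 = masked position), as registered on the model.
causal_mask = torch.triu(torch.full((max_len, max_len), fill_value=1), diagonal=1)

# Expand to (batch, 1, seq, seq) and convert the ones into large negative biases in the current dtype.
additive_mask = causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min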
@@ -1146,29 +1141,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-            if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+            if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
                 # generation with static cache
-                past_length = past_key_value.get_seq_length()
+                cache_position = kwargs.get("cache_position", None)
+                if cache_position is None:
+                    past_length = 0
+                else:
+                    past_length = cache_position[-1] + 1
                 input_ids = input_ids[:, past_length:]
                 position_ids = position_ids[:, past_length:]
 
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
 
         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
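A minimal sketch (standalone tensors, not the GemmaForCausalLM code) of what the rewritten block computes during decoding: `cache_position` is derived from `past_length` with a single `torch.arange`, and the inputs are made `contiguous()` so their strides stay static across decode steps under torch.compile.

import torch

past_length = 5                                   # e.g. tokens already written to the static KV cache
input_ids = torch.tensor([[42]])                  # one new token per step during decoding
position_ids = torch.tensor([[past_length]])

# Positions of the new tokens inside the cache: [past_length, past_length + seq_len).
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

model_inputs = {
    "input_ids": input_ids.contiguous(),          # static stride, so torchdynamo does not recompile on a stride guard
    "position_ids": position_ids.contiguous(),
    "cache_position": cache_position,
}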