Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
FIX [Gemma] Fix bad rebase with transformers main (#29170)
fix bad rebase
This commit is contained in:
parent 594c1277b2
commit ae49b218c3
@@ -124,7 +124,7 @@ def rotate_half(x):


 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
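For readers outside the diff: `position_ids` becomes optional because the `cos`/`sin` tensors passed in are assumed to already be gathered per position, so the function only has to rotate and broadcast. Below is a minimal standalone sketch of that rotation with toy tensors; the shapes and the `rotate_half` helper mirror the surrounding code, everything else is illustrative.

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Toy shapes: batch=1, num_heads=2, seq_len=4, head_dim=8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
# cos/sin are assumed to be pre-gathered per position, shape (batch, seq_len, head_dim),
# which is why the function no longer needs `position_ids` to index into them.
cos = torch.randn(1, 4, 8)
sin = torch.randn(1, 4, 8)

# unsqueeze_dim=1 inserts a head axis so cos/sin broadcast over num_heads.
cos_b, sin_b = cos.unsqueeze(1), sin.unsqueeze(1)
q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 4, 8]) twice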
@@ -940,6 +939,10 @@ class GemmaModel(GemmaPreTrainedModel):
             attentions=all_self_attns,
         )

+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
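The comment block added above describes a general torch.compile behavior rather than anything Gemma-specific: when an input's length changes at every decode step, dynamo re-specializes (and with cudagraphs, re-records a graph) per shape, whereas keeping the input at a fixed, padded length lets one compiled graph be reused. A rough, self-contained illustration of that trade-off follows; `masked_sum` is a made-up stand-in for an attention step, not code from the model.

import torch

def masked_sum(x, mask):
    # Toy stand-in for an op that consumes a mask of the same length as x.
    return (x * mask).sum()

compiled = torch.compile(masked_sum, dynamic=False)

# Dynamic shapes: every "decode step" sees a longer mask, so each new length
# triggers another shape-specialized compilation.
for length in (4, 5, 6):
    x = torch.randn(length)
    compiled(x, torch.ones(length))

# Static shapes: pad inputs to a fixed maximum length so the first compiled
# graph is reused on every step.
MAX_LEN = 8
for length in (4, 5, 6):
    x = torch.zeros(MAX_LEN)
    x[:length] = torch.randn(length)
    mask = torch.zeros(MAX_LEN)
    mask[:length] = 1.0
    compiled(x, mask)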
@@ -955,16 +958,8 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min

         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
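The simplification above always derives the mask from the registered `causal_mask` buffer and converts it to an additive bias in the current compute dtype; multiplying the 0/1 buffer by `torch.finfo(dtype).min` avoids the overflow issues a hard-coded large negative constant could cause in low precision. A standalone sketch of that conversion with made-up sizes:

import torch

batch_size, max_len, dtype = 2, 6, torch.float16

# Upper-triangular 0/1 buffer: 1 above the diagonal marks "future" (masked) positions.
causal_buffer = torch.triu(torch.full((max_len, max_len), fill_value=1), diagonal=1)

# Expand to (batch, 1, seq, seq) and turn the 1s into a large negative bias in the compute dtype.
causal_mask = (
    causal_buffer[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
print(causal_mask.shape)        # torch.Size([2, 1, 6, 6])
print(causal_mask[0, 0, 0, :])  # position 0 attends only to itself; future positions get ~ -65504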
@@ -1146,29 +1141,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]

         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}

         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
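To see the new static-cache bookkeeping end to end, here is a hedged, standalone sketch of what this hunk computes, with toy tensors and no model involved: derive `past_length` from `cache_position`, drop already-cached tokens, rebuild `cache_position` for the remaining ones, and call `.contiguous()` so the sliced tensors keep a static stride for torchdynamo.

import torch

# Pretend 3 tokens have already been written to the static cache.
input_ids = torch.tensor([[5, 8, 13, 21, 34]])
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
cache_position = torch.arange(0, 3)

# Mirror the diff's logic: past_length comes from the last cached position.
past_length = 0 if cache_position is None else int(cache_position[-1]) + 1
input_ids = input_ids[:, past_length:]        # only the not-yet-processed suffix
position_ids = position_ids[:, past_length:]

# Recompute cache positions for the remaining tokens.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

# `.contiguous()` keeps strides static after slicing, avoiding stride-guard recompiles.
model_inputs = {
    "input_ids": input_ids.contiguous(),
    "position_ids": position_ids.contiguous(),
    "cache_position": cache_position,
}
print(model_inputs)  # input_ids [[21, 34]], position_ids [[3, 4]], cache_position [3, 4]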