FIX [Gemma] Fix bad rebase with transformers main (#29170)

fix bad rebase
Younes Belkada, 2024-02-21 14:56:34 +01:00, committed by GitHub
parent 594c1277b2
commit ae49b218c3

@@ -124,7 +124,7 @@ def rotate_half(x):
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
     Args:
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
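For context on the signature change above, here is a minimal sketch of how the new `apply_rotary_pos_emb` is meant to be called: `cos` and `sin` are assumed to already be gathered for the current positions before the call, so `position_ids` is no longer used inside the function. The function body shown follows the standard Llama-style implementation (not part of this diff); the toy shapes and frequency computation are illustrative only.

    import torch

    def rotate_half(x):
        # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # `position_ids` is kept only for backward compatibility; `cos`/`sin` are
        # already indexed per position before this call.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    # Toy example: batch=1, heads=2, seq_len=4, head_dim=8
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
    freqs = torch.outer(torch.arange(4).float(), inv_freq)   # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                  # (seq_len, head_dim)
    cos, sin = emb.cos()[None], emb.sin()[None]              # (1, seq_len, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)      # no position_ids needed
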
@@ -940,6 +939,10 @@ class GemmaModel(GemmaPreTrainedModel):
             attentions=all_self_attns,
         )
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
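The TODO comment added in the hunk above mentions `@torch.compiler.disable` as a workaround that rules out `fullgraph=True`. Here is a small, hypothetical illustration of that trade-off; the mask and attention functions are toy code, not transformers code, and only show why the disabled call forces a graph break.

    import torch

    @torch.compiler.disable
    def make_padding_bias(attention_mask):
        # Dynamic-length 2D mask -> additive bias; kept out of the compiled graph.
        return (1.0 - attention_mask[:, None, None, :]) * torch.finfo(torch.float32).min

    @torch.compile  # passing fullgraph=True here would fail: the disabled call forces a graph break
    def attend(scores, attention_mask):
        return torch.softmax(scores + make_padding_bias(attention_mask), dim=-1)

    scores = torch.zeros(1, 1, 1, 4)
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
    print(attend(scores, mask))  # last position is masked out, gets ~0 probability
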
@@ -955,16 +958,8 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
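As a standalone illustration of the masking pattern the replacement code above relies on (a persistent `causal_mask` buffer of ones above the diagonal, scaled into an additive bias in the current dtype), here is a short self-contained sketch; `max_positions`, the batch size, and the dtype are arbitrary toy values.

    import torch

    max_positions, batch_size = 8, 2
    dtype, device = torch.float16, "cpu"

    # Buffer registered once at init: 1s strictly above the diagonal, 0s elsewhere.
    causal_buffer = torch.triu(torch.full((max_positions, max_positions), fill_value=1), diagonal=1)

    # Per forward pass: broadcast to (batch, 1, seq, seq) and turn the 1s into the most
    # negative representable value so that softmax ignores future positions.
    causal_mask = causal_buffer[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
    causal_mask = causal_mask.to(dtype=dtype, device=device)

    print(causal_mask[0, 0, :3, :3])  # 0 on/below the diagonal, large negative values above
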
@@ -1146,29 +1141,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),