Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix generate with inputs_embeds as input (#32493)

* I think inputs_embeds has ndim == 3
* fix sequence length catch
* add generate test
* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama
* skip whisper
* fix bart test
* more fixes
parent b01f9c484c
commit 044281605f
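The root cause the message alludes to: `inputs_embeds` is a 3-D tensor of shape (batch_size, sequence_length, hidden_size), so the old static-cache branch below, which unpacked its shape into two values, raised a ValueError on the first generation step. A minimal sketch of the failure and of the corrected unpacking (tensor sizes here are illustrative, not taken from the patch):

import torch

inputs_embeds = torch.randn(2, 5, 64)  # (batch_size, sequence_length, hidden_size) -> ndim == 3

# old branch: raises "ValueError: too many values to unpack (expected 2)"
# batch_size, sequence_length = inputs_embeds.shape

# patched branch: discard the hidden dimension explicitly
batch_size, sequence_length, _ = inputs_embeds.shape
device = inputs_embeds.device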
@@ -756,17 +756,18 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
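For orientation, a hedged usage sketch of the call path this hunk (and the identical ones below) repairs: generating from `inputs_embeds` with a static cache, which routes execution through the fixed branch on the first step. The checkpoint name is a placeholder and not part of the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your/causal-lm-checkpoint"  # placeholder, any supported causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)  # build the embeddings outside the model

# with a static cache, prepare_inputs_for_generation takes the branch patched above;
# before the fix this call failed while computing batch_size/sequence_length
output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))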
@@ -1132,17 +1132,18 @@ class CohereForCausalLM(CoherePreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1403,17 +1403,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1270,17 +1270,18 @@ class FalconForCausalLM(FalconPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1143,17 +1143,18 @@ class GemmaForCausalLM(GemmaPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
@@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
             padding_mask, min_dtype
         )
-
     return causal_mask


@@ -301,7 +300,6 @@ class Gemma2Attention(nn.Module):
             attn_weights = attn_weights / self.config.attn_logit_softcapping
             attn_weights = torch.tanh(attn_weights)
             attn_weights = attn_weights * self.config.attn_logit_softcapping
-
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask

@@ -501,11 +499,9 @@ class Gemma2SdpaAttention(Gemma2Attention):

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and causal_mask is not None:

@@ -516,7 +512,6 @@ class Gemma2SdpaAttention(Gemma2Attention):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
-
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,

@@ -581,7 +576,6 @@ class Gemma2DecoderLayer(nn.Module):
            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
            if attention_mask.shape[-1] <= 1:  # when decoding
                attention_mask = attention_mask[:, :, :, -self.sliding_window :]
-
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

@@ -1013,7 +1007,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,

@@ -1080,7 +1073,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
-
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1

@@ -1096,22 +1088,20 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
-
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
-
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,

@@ -1122,7 +1112,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
                cache_position=cache_position,
                batch_size=batch_size,
            )
-
        model_inputs.update(
            {
                "position_ids": position_ids,
@@ -970,17 +970,18 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1220,17 +1220,18 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.embed_out.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1100,17 +1100,18 @@ class GPTJForCausalLM(GPTJPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1265,17 +1265,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1136,17 +1136,18 @@ class NemotronForCausalLM(NemotronPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1176,17 +1176,18 @@ class OlmoForCausalLM(OlmoPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -993,17 +993,18 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1278,17 +1278,18 @@ class PhiForCausalLM(PhiPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1318,17 +1318,18 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1176,17 +1176,18 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1372,17 +1372,18 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1271,17 +1271,18 @@ class StableLmForCausalLM(StableLmPreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min

@@ -1152,17 +1152,18 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if inputs_embeds is not None:
-                batch_size, sequence_length = inputs_embeds.shape
-                device = inputs_embeds.device
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
             else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device

             dtype = self.lm_head.weight.dtype
             min_dtype = torch.finfo(dtype).min
@@ -1540,3 +1540,8 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
     @unittest.skip
     def test_save_load_fast_init_from_base(self):
         pass
+
+    @unittest.skip(reason="Generate needs input ids")
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # generate only works with input ids for bartforcausalLM
+        pass

@@ -502,6 +502,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

+    @unittest.skip(reason="Generate needs input ids")
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # generate only works with input ids for bertforcausalLM
+        pass
+
     def test_model_as_decoder_with_default_input_mask(self):
         # This regression test was failing with PyTorch < 1.3
         (

@@ -4058,6 +4058,11 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
         # generate only works with input ids for whisper
         pass

+    @unittest.skip(reason="Generate needs input ids")
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # generate only works with input ids for whisper
+        pass
+
     @unittest.skip(reason="Decoder can't keep attention grads")
     def test_retain_grad_hidden_states_attentions(self):
         return

@@ -2819,6 +2819,53 @@ class ModelTesterMixin:
                )[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
+                continue
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            model_forward_args = inspect.signature(model.forward).parameters
+            if "inputs_embeds" not in model_forward_args:
+                self.skipTest(reason="This model doesn't use `inputs_embeds`")
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                # some models infer position ids/attn mask differently when input ids
+                # by check if pad_token let's make sure no padding is in input ids
+                not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+                input_ids[input_ids == pad_token_id] = not_pad_token_id
+                del inputs["input_ids"]
+                inputs_embeds = wte(input_ids)
+                out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
+                out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
+                decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+                inputs_embeds = wte(encoder_input_ids)
+                decoder_inputs_embeds = wte(decoder_input_ids)
+                out_ids = model.generate(
+                    input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
+                )[:, -2:]
+                out_embeds = model.generate(
+                    inputs_embeds=inputs_embeds,
+                    decoder_inputs_embeds=decoder_inputs_embeds,
+                    **inputs,
+                    max_new_tokens=2,
+                )
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+
     @require_torch_multi_gpu
     def test_multi_gpu_data_parallel_forward(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()