Remove unnecessary views of position_ids (#26059)

* Remove unnecessary `view` of `position_ids` in `modeling_llama`

When `position_ids` is `None`, its value is generated using
`torch.arange`, which creates a tensor of size `(seq_length +
past_key_values_length) - past_key_values_length = seq_length`. The
tensor is then unsqueezed, resulting in a tensor of shape `(1,
seq_length)`. This means that the last `view` to a tensor of shape
`(-1, seq_length)` is a no-op.

This commit removes the unnecessary view.
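
For illustration, a minimal standalone PyTorch snippet (made-up sizes, not
part of the patch) that checks the reshape is a no-op:

    import torch

    seq_length = 5
    past_key_values_length = 3

    # Mirrors the fallback branch that builds position_ids when none are passed.
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
    )
    position_ids = position_ids.unsqueeze(0)  # shape: (1, seq_length)

    # The removed call: viewing (1, seq_length) as (-1, seq_length) changes nothing.
    assert position_ids.view(-1, seq_length).shape == position_ids.shape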

* Remove no-op `view` of `position_ids` in the rest of the transformer models
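
The same reshape is also shape-preserving for a user-supplied `position_ids` of
shape `(batch_size, seq_length)`, which is why the `view` in the non-`None`
branches can go as well (a hypothetical check with made-up sizes; it assumes
callers already pass 2-D long tensors):

    import torch

    batch_size, seq_length = 2, 5
    position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)

    # view(-1, seq_length) leaves a (batch_size, seq_length) tensor unchanged.
    assert position_ids.view(-1, seq_length).shape == position_ids.shape
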
Author: Ramiro Leal-Cavazos, 2023-10-06 08:28:00 +00:00 (committed by GitHub)
parent 75a33d60f2
commit 8878eb1bd9
15 changed files with 15 additions and 46 deletions


@@ -475,9 +475,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1]).long()

         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
@@ -486,7 +483,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:


@@ -416,7 +416,7 @@ class CTRLModel(CTRLPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:
@@ -447,7 +447,6 @@ class CTRLModel(CTRLPreTrainedModel):
             token_type_embeds *= np.sqrt(self.d_model_size)
         else:
             token_type_embeds = 0
-        position_ids = position_ids.view(-1, input_shape[-1])

         if inputs_embeds is None:
             inputs_embeds = self.w(input_ids)


@@ -544,8 +544,6 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past_key_values is None:
             past_length = 0
@@ -554,7 +552,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # GPT2Attention mask.
         if attention_mask is not None:


@@ -630,9 +630,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)


@@ -1128,9 +1128,7 @@ class FalconModel(FalconPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         causal_mask = self._prepare_attn_mask(
             attention_mask,


@@ -790,8 +790,6 @@ class GPT2Model(GPT2PreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past_key_values is None:
             past_length = 0
@@ -800,7 +798,7 @@ class GPT2Model(GPT2PreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # GPT2Attention mask.
         if attention_mask is not None:


@@ -577,8 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past_key_values is None:
             past_length = 0
@@ -594,7 +592,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                 position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
         elif position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # Self-attention mask.
         query_length = input_shape[-1]


@@ -539,8 +539,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past_key_values is None:
             past_length = 0
@@ -550,7 +548,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:


@@ -596,9 +596,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:


@@ -592,9 +592,6 @@ class GPTJModel(GPTJPreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1]).long()

         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
@@ -603,7 +600,7 @@ class GPTJModel(GPTJPreTrainedModel):
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # Attention mask.
         if attention_mask is not None:


@@ -1208,9 +1208,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         no_images = False
         if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:


@@ -729,8 +729,6 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])

         if past_key_values is None:
             past_length = 0
@@ -739,7 +737,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         # ImageGPTAttention mask.
         if attention_mask is not None:


@@ -867,9 +867,7 @@ class LlamaModel(LlamaPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)


@@ -638,9 +638,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)


@@ -623,9 +623,7 @@ class XGLMModel(XGLMPreTrainedModel):
                 dtype=torch.long,
                 device=input_ids.device if input_ids is not None else inputs_embeds.device,
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
-        else:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale