Use more self-explanatory variables

calpt 2023-08-04 20:44:54 +02:00
parent a6d3c212ec
commit 5fd9ab3911
4 changed files with 32 additions and 36 deletions


@@ -530,21 +530,20 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            seq_length = input_ids.shape[-1]
+            input_ids = input_ids.view(-1, seq_length)
             batch_size = input_ids.shape[0]
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_ids = token_type_ids.view(-1, seq_length)
         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.view(-1, seq_length)
 
         if past_key_values is None:
             past_length = 0
@@ -552,8 +551,8 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 
         # GPT2Attention mask.
         if attention_mask is not None:
@@ -603,7 +602,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
 
         hidden_states = self.drop(hidden_states)
 
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
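The hunks above replace the `input_shape` tuple with a named `seq_length` in each input branch. A minimal standalone sanity check (not part of the commit; the tensor shapes below are invented for illustration) showing that the two styles agree, and that the tuple unpacking in the `inputs_embeds` branch relies on a 3D (batch, sequence, hidden) tensor:

import torch

# Hypothetical inputs, just for this check.
input_ids = torch.randint(0, 100, (2, 5))       # (batch_size, seq_length)
inputs_embeds = torch.randn(2, 5, 768)          # (batch_size, seq_length, hidden_size)

# Old bookkeeping: keep the whole shape around and index into it where needed.
input_shape = input_ids.size()
old_view = input_ids.view(-1, input_shape[-1])

# New bookkeeping: name the single dimension that is actually used.
seq_length = input_ids.shape[-1]
new_view = input_ids.view(-1, seq_length)
assert torch.equal(old_view, new_view)

# For inputs_embeds, size()[:-1] is (batch_size, seq_length), so one tuple
# unpacking replaces the two separate lookups. This assumes inputs_embeds
# is exactly 3D, which is what these GPT-style models expect.
batch_size, seq_length = inputs_embeds.size()[:-1]
assert (batch_size, seq_length) == (2, 5)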


@@ -776,21 +776,20 @@ class GPT2Model(GPT2PreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            seq_length = input_ids.shape[-1]
+            input_ids = input_ids.view(-1, seq_length)
             batch_size = input_ids.shape[0]
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_ids = token_type_ids.view(-1, seq_length)
         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.view(-1, seq_length)
 
         if past_key_values is None:
             past_length = 0
@@ -798,8 +797,8 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 
         # GPT2Attention mask.
         if attention_mask is not None:
@@ -849,7 +848,7 @@ class GPT2Model(GPT2PreTrainedModel):
 
         hidden_states = self.drop(hidden_states)
 
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
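In the position_ids hunks, only the upper bound of the `torch.arange` changes from `input_shape[-1]` to `seq_length`: the default positions still continue from the length of the cached keys/values. A small sketch of that default (the concrete numbers are made up for illustration):

import torch

past_length = 3      # tokens already stored in past_key_values (hypothetical)
seq_length = 2       # new tokens in this forward pass (hypothetical)
device = torch.device("cpu")

position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[3, 4]]) -- positions continue where the cache left off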


@@ -525,21 +525,20 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            seq_length = input_ids.shape[-1]
+            input_ids = input_ids.view(-1, seq_length)
             batch_size = input_ids.shape[0]
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_ids = token_type_ids.view(-1, seq_length)
         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.view(-1, seq_length)
 
         if past_key_values is None:
             past_length = 0
@@ -548,8 +547,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
 
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 
         # Attention mask.
         if attention_mask is not None:
@@ -588,7 +587,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
 
         hidden_states = self.drop(hidden_states)
 
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:


@@ -577,22 +577,21 @@ class GPTJModel(GPTJPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            seq_length = input_ids.shape[-1]
+            input_ids = input_ids.view(-1, seq_length)
             batch_size = input_ids.shape[0]
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+            token_type_ids = token_type_ids.view(-1, seq_length)
 
         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1]).long()
+            position_ids = position_ids.view(-1, seq_length).long()
 
         if past_key_values is None:
             past_length = 0
@@ -601,8 +600,8 @@ class GPTJModel(GPTJPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
 
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 
         # Attention mask.
         if attention_mask is not None:
@@ -641,7 +640,7 @@ class GPTJModel(GPTJPreTrainedModel):
 
         hidden_states = self.drop(hidden_states)
 
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
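The last hunk in each file rewrites `output_shape` as a plain literal. Since `input_shape[1:]` is just `(seq_length,)` once `input_ids` has been flattened to two dimensions, the two forms build the same tuple. A quick equivalence check (illustrative only, with invented shapes):

import torch

input_ids = torch.zeros(2, 5, dtype=torch.long)   # (batch_size, seq_length), hypothetical
hidden_states = torch.randn(2, 5, 768)            # (batch_size, seq_length, hidden_size)

input_shape = input_ids.size()                    # old bookkeeping
seq_length = input_ids.shape[-1]                  # new bookkeeping

old_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
new_shape = (-1, seq_length, hidden_states.size(-1))
assert tuple(old_shape) == new_shape == (-1, 5, 768)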