Fix position embeddings for GPT-J and CodeGen (#22069)
* Revert "[GPT-J] add deprecation warning (#21869)"
This reverts commit fb76994c41.
* Fix position embeddings for GPT-J and CodeGen
* Address review comments from @gante
* Fix "Copied from" comment referencing wrong function
* Fix copy/paste mistake
* Fix training path
* Hopefully make torch.fx happy
* Move position_ids long cast
* Revert "Hopefully make torch.fx happy"
This reverts commit e41a6f4cad3ff441124c7457b19cfb630d4ca025.
* Changes to help with torch.fx tracing
* Linter fix
* Correct position_ids tensor type hint
* Work-around torch.fx tracing issue
* Get the changes to work with torch.fx
* Address review comment from @michaelbenayoun
* Another small adjustment
* Add explanatory comment; small code tidyup
This commit is contained in: parent 8e6c34b390, commit 4e94c6c008
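In short: GPT-J and CodeGen now build their rotary sin/cos values from an explicit `position_ids` tensor instead of deriving them from the key length and cache offset, so padded batches get correct position embeddings and the models trace under torch.fx again. A minimal sketch of the user-facing effect; the tiny test checkpoint name is an illustrative assumption, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical tiny checkpoint used only to keep the example cheap; any GPT-J or
    # CodeGen checkpoint accepts position_ids the same way after this change.
    checkpoint = "hf-internal-testing/tiny-random-gptj"
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.pad_token = tokenizer.eos_token

    batch = tokenizer(["hello", "a much longer prompt"], padding=True, return_tensors="pt")
    # Skip padding positions, mirroring what the generation code below now does.
    position_ids = batch["attention_mask"].long().cumsum(-1) - 1
    position_ids.masked_fill_(batch["attention_mask"] == 0, 1)

    outputs = model(**batch, position_ids=position_ids)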
src/transformers/models/codegen/modeling_codegen.py

@@ -51,43 +51,26 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-# Copied from transformers.models.gptj.modeling_gptj.fixed_pos_embedding
-def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
-    dim = x.shape[-1]
-    if seq_len is None:
-        seq_len = x.shape[seq_dim]
-    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
-    sinusoid_inp = (
-        torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
-    )
-    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
+def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
+    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
+    return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


 # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
-def rotate_every_two(x):
+def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')


-# Copied from transformers.models.gptj.modeling_gptj.duplicate_interleave
-def duplicate_interleave(m):
-    """
-    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
-    """
-    dim0 = m.shape[0]
-    m = m.view(-1, 1)  # flatten the matrix
-    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
-    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
-    return m
-
-
 # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
-def apply_rotary_pos_emb(x, sincos, offset=0):
-    sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
-    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
-    return (x * cos) + (rotate_every_two(x) * sin)
+def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    return (tensor * cos) + (rotate_every_two(tensor) * sin)


 class CodeGenAttention(nn.Module):
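The table returned by `create_sinusoidal_positions` packs sin values in the first half of its last dimension and cos values in the second half, and `apply_rotary_pos_emb` now consumes per-position slices of that table instead of an offset. A shape-level sketch, assuming the post-merge module layout so the helpers can be imported directly:

    import torch
    from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb, create_sinusoidal_positions

    max_positions, rotary_dim = 2048, 64
    table = create_sinusoidal_positions(max_positions, rotary_dim)    # (2048, 64): [sin | cos] halves

    batch, seq_len, num_heads = 2, 5, 16
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)    # (2, 5)
    sincos = table[position_ids]                                           # (2, 5, 64), as CodeGen indexes it
    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)          # each (2, 5, 32)

    q_rot = torch.randn(batch, seq_len, num_heads, rotary_dim)             # (batch, seq, head, rotary_dim)
    q_rot = apply_rotary_pos_emb(q_rot, sin, cos)                          # same shape, rotation applied per position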
@@ -117,9 +100,9 @@ class CodeGenAttention(nn.Module):
         self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.rotary_dim = None
-        if config.rotary_dim is not None:
-            self.rotary_dim = config.rotary_dim
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

     def _split_heads(self, x, n_head, dim_head, mp_num):
         reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))

@@ -183,8 +166,9 @@ class CodeGenAttention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
-        attention_mask: Optional[torch.FloatTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,

@@ -205,12 +189,13 @@ class CodeGenAttention(nn.Module):
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
         value = value.permute(0, 2, 1, 3)

-        seq_len = key.shape[1]
-        offset = 0
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions

-        if layer_past is not None:
-            offset = layer_past[0].shape[-2]
-            seq_len += offset
+        sincos = embed_positions[position_ids]
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]

@@ -219,16 +204,14 @@ class CodeGenAttention(nn.Module):
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]

-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

             key = torch.cat([k_rot, k_pass], dim=-1)
             query = torch.cat([q_rot, q_pass], dim=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)

         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)

@@ -292,6 +275,7 @@ class CodeGenBlock(nn.Module):
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,

@@ -299,9 +283,10 @@ class CodeGenBlock(nn.Module):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -488,7 +473,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
             token_type_ids = token_type_ids.view(-1, input_shape[-1])

         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.view(-1, input_shape[-1]).long()

         if past_key_values is None:
             past_length = 0

@@ -568,13 +553,15 @@ class CodeGenModel(CodeGenPreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = block(
-                    hidden_states,
+                    hidden_states=hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,

@@ -645,8 +632,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None

         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
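Dropping the `else: position_ids = None` branch means a caller-supplied `position_ids` is no longer discarded during generation. For reference, the cumsum/masked_fill recipe above turns a left-padded attention mask into positions that start at 0 on the first real token (the values below are just an illustration):

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

The padded slots receive the dummy value 1; they are excluded by the attention mask anyway.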
src/transformers/models/gptj/modeling_gptj.py

@@ -30,7 +30,13 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_torch_fx_proxy,
+    logging,
+)
 from ...utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_gptj import GPTJConfig

@@ -48,39 +54,28 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
-    dim = x.shape[-1]
-    if seq_len is None:
-        seq_len = x.shape[seq_dim]
+def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
-    sinusoid_inp = (
-        torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
-    )
-    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
+    return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


-def rotate_every_two(x):
+@torch.fx.wrap
+def get_embed_positions(embed_positions, position_ids):
+    return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
+
+
+def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')


-def duplicate_interleave(m):
-    """
-    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
-    """
-    dim0 = m.shape[0]
-    m = m.view(-1, 1)  # flatten the matrix
-    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
-    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
-    return m
-
-
-def apply_rotary_pos_emb(x, sincos, offset=0):
-    sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
-    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
-    return (x * cos) + (rotate_every_two(x) * sin)
+def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    return (tensor * cos) + (rotate_every_two(tensor) * sin)


 class GPTJAttention(nn.Module):
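`@torch.fx.wrap` registers `get_embed_positions` as a leaf call, so symbolic tracing records one opaque node instead of trying to trace the data-dependent device copy. A standalone sketch of the same mechanism with a toy function (not code from this commit):

    import torch
    import torch.fx


    @torch.fx.wrap
    def lookup_positions(table, position_ids):
        # Not traced into: the device handling stays hidden inside the leaf call.
        return table.to(position_ids.device)[position_ids]


    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("table", torch.randn(16, 4))

        def forward(self, position_ids):
            return lookup_positions(self.table, position_ids)


    traced = torch.fx.symbolic_trace(Toy())
    print(traced.graph)  # lookup_positions shows up as a single call_function node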
@@ -113,9 +108,9 @@ class GPTJAttention(nn.Module):
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.rotary_dim = None
-        if config.rotary_dim is not None:
-            self.rotary_dim = config.rotary_dim
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """

@@ -187,11 +182,19 @@ class GPTJAttention(nn.Module):

         return attn_output, attn_weights

+    def _get_embed_positions(self, position_ids):
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions
+        return embed_positions.repeat(position_ids.shape[0], 1, 1)
+
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
-        attention_mask: Optional[torch.FloatTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,

@@ -207,12 +210,16 @@ class GPTJAttention(nn.Module):
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

-        seq_len = key.shape[1]
-        offset = 0
+        if is_torch_fx_proxy(position_ids):
+            # The logic to conditionally copy to GPU could not be traced, so we do this
+            # every time in the torch.fx case
+            embed_positions = get_embed_positions(self.embed_positions, position_ids)
+        else:
+            embed_positions = self._get_embed_positions(position_ids)

-        if layer_past is not None:
-            offset = layer_past[0].shape[-2]
-            seq_len += offset
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
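Indexing the batched table with `torch.gather` is equivalent to advanced indexing by `position_ids`, but it is expressed with an op whose output shape the meta overrides in `utils/fx.py` (further down in this diff) can compute during tracing. A quick equivalence check:

    import torch

    batch, max_pos, dim = 2, 8, 6
    table = torch.randn(max_pos, dim)
    embed_positions = table.repeat(batch, 1, 1)              # (2, 8, 6), as _get_embed_positions returns
    position_ids = torch.tensor([[0, 1, 2], [3, 4, 5]])      # (2, 3)

    repeated = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])   # (2, 3, 6)
    sincos = torch.gather(embed_positions, 1, repeated)                             # (2, 3, 6)

    reference = embed_positions[torch.arange(batch).unsqueeze(-1), position_ids]
    assert torch.equal(sincos, reference)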
@@ -221,16 +228,14 @@ class GPTJAttention(nn.Module):
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]

-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

             key = torch.cat([k_rot, k_pass], dim=-1)
             query = torch.cat([q_rot, q_pass], dim=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)

         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)

@@ -292,6 +297,7 @@ class GPTJBlock(nn.Module):
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,

@@ -299,9 +305,10 @@ class GPTJBlock(nn.Module):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -391,6 +398,11 @@ GPTJ_INPUTS_DOCSTRING = r"""
             - 1 corresponds to a *sentence B* token.

             [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

@@ -544,24 +556,14 @@ class GPTJModel(GPTJPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **deprecated_arguments,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        if deprecated_arguments.pop("position_ids", False) is not False:
-            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-            warnings.warn(
-                "`position_ids` have no functionality in GPT-J and will be removed in v5.0.0. You can safely ignore"
-                " passing `position_ids`.",
-                FutureWarning,
-            )
-        if len(deprecated_arguments) > 0:
-            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

@@ -581,11 +583,23 @@ class GPTJModel(GPTJPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])

+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1]).long()
+
         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)

+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
         # Attention mask.
         if attention_mask is not None:
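When the caller passes no `position_ids`, the model falls back to the previous behaviour: a contiguous range starting at `past_length`, reshaped to `(1, seq_len)` so it broadcasts across the batch. For example:

    import torch

    past_length, seq_len = 4, 3
    position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long)
    position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
    print(position_ids)  # tensor([[4, 5, 6]])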
@@ -665,13 +679,15 @@ class GPTJModel(GPTJPreTrainedModel):
                     hidden_states,
                     None,
                     attention_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = block(
-                    hidden_states,
+                    hidden_states=hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,

@@ -766,9 +782,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
-    ):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:

@@ -777,6 +791,14 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

         attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
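Once `past_key_values` is populated, only the newly generated token is fed back in, so only its position is kept. Continuing the example values from the CodeGen section above:

    import torch

    position_ids = torch.tensor([[1, 1, 0, 1, 2],
                                 [0, 1, 2, 3, 4]])
    last_position = position_ids[:, -1].unsqueeze(-1)
    print(last_position)  # tensor([[2], [4]])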
@@ -787,7 +809,8 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         model_inputs.update(
             {
                 "past_key_values": past_key_values,
-                "use_cache": use_cache,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
                 "attention_mask": attention_mask,
                 "token_type_ids": token_type_ids,
             }

@@ -808,6 +831,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -815,7 +839,6 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **deprecated_arguments,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -823,16 +846,6 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
-        if deprecated_arguments.pop("position_ids", False) is not False:
-            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-            warnings.warn(
-                "`position_ids` have no functionality in GPT-J and will be removed in v5.0.0. You can safely ignore"
-                " passing `position_ids`.",
-                FutureWarning,
-            )
-        if len(deprecated_arguments) > 0:
-            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(

@@ -840,6 +853,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,

@@ -941,6 +955,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -948,7 +963,6 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **deprecated_arguments,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -956,16 +970,6 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        if deprecated_arguments.pop("position_ids", False) is not False:
-            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-            warnings.warn(
-                "`position_ids` have no functionality in GPT-J and will be removed in v5.0.0. You can safely ignore"
-                " passing `position_ids`.",
-                FutureWarning,
-            )
-        if len(deprecated_arguments) > 0:
-            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.transformer(

@@ -973,6 +977,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,

@@ -1074,6 +1079,7 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         start_positions: Optional[torch.LongTensor] = None,

@@ -1081,7 +1087,6 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **deprecated_arguments,
     ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1093,22 +1098,13 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        if deprecated_arguments.pop("position_ids", False) is not False:
-            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-            warnings.warn(
-                "`position_ids` have no functionality in GPT-J and will be removed in v5.0.0. You can safely ignore"
-                " passing `position_ids`.",
-                FutureWarning,
-            )
-        if len(deprecated_arguments) > 0:
-            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
src/transformers/utils/fx.py

@@ -363,6 +363,26 @@ def torch_tensor_repeat(self, *sizes):
     return torch.empty(shape, device="meta")


+def torch_repeat_interleave(*args, dim=None, output_size=None):
+    num_args = len(args)
+    if num_args == 1:
+        shape = [output_size if output_size is not None else args[0].sum()]
+    else:
+        shape = list(args[0].shape)
+        if dim is None:
+            if num_args > 2:
+                dim = args[2]
+            else:
+                shape = [sum(shape)]
+                dim = 0
+        repeats = args[1]
+        if isinstance(repeats, int) or torch.numel(repeats) == 1:
+            shape[dim] *= int(repeats)
+        else:
+            shape[dim] = output_size if output_size is not None else repeats.sum()
+    return torch.empty(*shape, device="meta")
+
+
 def torch_index_select(input, dim, index, *, out=None):
     shape = list(input.shape)
     shape[dim] = len(index)
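These overrides never execute the real op; during fx tracing they only report the output shape on the `meta` device. A quick check of the three-argument form used by `apply_rotary_pos_emb`, assuming the post-merge module layout so the override can be imported and called directly:

    import torch
    from transformers.utils.fx import torch_repeat_interleave  # assumption: importable after this commit

    sin = torch.empty(2, 5, 1, 32, device="meta")
    out = torch_repeat_interleave(sin, 2, 3)   # mirrors torch.repeat_interleave(sin, 2, 3)
    print(out.shape)  # torch.Size([2, 5, 1, 64]) -- shape only, no data is touched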
@@ -373,6 +393,16 @@ def torch_tensor_index_select(self, dim, index):
     return torch_index_select(self, dim, index)


+def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
+    shape = list(input.shape)
+    shape[dim] = index.shape[dim]
+    return torch.empty(*shape, device="meta")
+
+
+def torch_tensor_gather(self, dim, index):
+    return torch_gather(self, dim, index)
+
+
 def torch_roll(input, shifts, dims=None):
     return input


@@ -539,11 +569,14 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.Tensor.baddbmm: torch_tensor_baddbmm,
     torch.einsum: torch_einsum,
     torch.Tensor.repeat: torch_tensor_repeat,
+    torch.repeat_interleave: torch_repeat_interleave,
     torch.roll: torch_roll,
     torch.flip: torch_flip,
     torch.Tensor.flip: torch_tensor_flip,
     torch.index_select: torch_index_select,
     torch.Tensor.index_select: torch_tensor_index_select,
+    torch.gather: torch_gather,
+    torch.Tensor.gather: torch_tensor_gather,
     torch.nn.Conv1d: torch_nn_conv1d,
     torch.nn.Conv2d: torch_nn_conv2d,
     torch.squeeze: torch_squeeze,
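With `torch.repeat_interleave` and `torch.gather` registered in the override table, GPT-J should again be traceable through the library's fx helper. A hedged usage sketch with a tiny random config (the config values are illustrative only):

    from transformers import GPTJConfig, GPTJForCausalLM
    from transformers.utils.fx import symbolic_trace

    config = GPTJConfig(n_layer=2, n_head=4, n_embd=64, rotary_dim=16, vocab_size=128)
    model = GPTJForCausalLM(config)

    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
    print(type(traced))  # a torch.fx GraphModule, with the rotary ops resolved via the meta overrides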