mirror of https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00

language update

This commit is contained in:
parent 210d407245
commit 5ff0c60505
@@ -237,17 +237,17 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)
 
-    def forward(self, x, past=None):
+    def forward(self, x, layer_past=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        present = key, value
-        if past is not None:
-            past_key, past_value = past
+        if layer_past is not None:
+            past_key, past_value = layer_past[0], layer_past[1]
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
+        present = torch.stack((key, value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
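
A minimal, self-contained sketch (not the library code) of the caching pattern introduced above: each forward pass packs the current key/value into `present` via torch.stack, and the next pass receives it back as `layer_past`, unpacks it, and concatenates along the sequence dimension. Head splitting and the transposition done by split_heads are ignored here; `cache_step` and the tensor shapes are illustrative assumptions.

import torch

def cache_step(key, value, layer_past=None):
    # key/value: (batch, heads, new_seq, head_dim) for this step only
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        key = torch.cat((past_key, key), dim=-2)      # grow along the sequence axis
        value = torch.cat((past_value, value), dim=-2)
    present = torch.stack((key, value))               # (2, batch, heads, total_seq, head_dim)
    return present

step1 = cache_step(torch.randn(2, 12, 1, 64), torch.randn(2, 12, 1, 64))
step2 = cache_step(torch.randn(2, 12, 1, 64), torch.randn(2, 12, 1, 64), layer_past=step1)
print(step2.shape)  # torch.Size([2, 2, 12, 2, 64]) -- the cache now covers two positions
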
@@ -277,8 +277,8 @@ class Block(nn.Module):
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
-    def forward(self, x, past=None):
-        a, present = self.attn(self.ln_1(x), past=past)
+    def forward(self, x, layer_past=None):
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
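
The forward above is the standard pre-LayerNorm residual block, with the per-layer cache threaded through the attention sub-layer and handed back to the caller. Below is a toy sketch of that control flow only; ToyBlock and its stand-in attention (which just concatenates its cache) are illustrative assumptions, not the real Attention/MLP modules.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, nx):
        super().__init__()
        self.ln_1, self.ln_2 = nn.LayerNorm(nx), nn.LayerNorm(nx)
        self.attn_proj, self.mlp = nn.Linear(nx, nx), nn.Linear(nx, nx)

    def attn(self, x, layer_past=None):
        # Stand-in attention: keep a running cache of normalized inputs.
        key_value = x if layer_past is None else torch.cat((layer_past, x), dim=1)
        return self.attn_proj(key_value)[:, -x.size(1):, :], key_value  # (output, new cache)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a                        # residual connection around attention
        m = self.mlp(self.ln_2(x))
        x = x + m                        # residual connection around the MLP
        return x, present

block = ToyBlock(8)
x1, cache = block(torch.randn(1, 3, 8))
x2, cache = block(torch.randn(1, 1, 8), layer_past=cache)
print(cache.shape)  # torch.Size([1, 4, 8]) -- cache grew by one position
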
@@ -346,7 +346,7 @@ class GPT2PreTrainedModel(nn.Module):
            )
         self.config = config
 
-    def set_tied():
+    def set_tied(self):
         pass
 
     def init_weights(self, module):
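
The signature fix matters because an instance method defined without self fails as soon as it is called on an instance: Python passes the instance as the first positional argument. A quick illustration with throwaway classes (Broken and Fixed are hypothetical, not from the library):

class Broken:
    def set_tied():            # missing self, as in the old signature
        pass

class Fixed:
    def set_tied(self):
        pass

Fixed().set_tied()             # works
try:
    Broken().set_tied()
except TypeError as err:       # the instance was passed, but the method accepts no arguments
    print(err)                 # set_tied() takes 0 positional arguments but 1 was given
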
@@ -526,12 +526,12 @@ class GPT2Model(GPT2PreTrainedModel):
 
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, pasts=None):
-        if pasts is None:
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
+        if past is None:
             past_length = 0
-            pasts = [None] * len(self.h)
+            past = [None] * len(self.h)
         else:
-            past_length = pasts[0][0].size(-2)
+            past_length = past[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
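
The only job of past_length here is to shift the positional indices so that tokens fed together with a cache continue the position count instead of restarting at zero. A standalone illustration of that arange call (the token id and cache length are arbitrary values):

import torch

input_ids = torch.tensor([[50256]])          # one new token per sequence
past_length = 5                              # cache already holds 5 positions, i.e. past[0][0].size(-2)
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                            dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)                          # tensor([[5]]) -- the new token sits at position 5
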
@@ -549,8 +549,8 @@ class GPT2Model(GPT2PreTrainedModel):
             token_type_embeds = 0
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         presents = []
-        for block, past in zip(self.h, pasts):
-            hidden_states, present = block(hidden_states, past)
+        for block, layer_past in zip(self.h, past):
+            hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
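
The rename also clarifies the loop: past is a list with one entry per transformer block (all None on the first pass), and zip hands each block its own layer_past while the fresh presents are collected for the caller. A tiny standalone sketch of that bookkeeping, with plain functions standing in for the blocks in self.h:

blocks = [lambda h, layer_past, i=i: (h + 1, (i, h)) for i in range(3)]   # stand-ins for self.h
past = [None] * len(blocks)                                               # first pass: no cache

hidden_states = 0
presents = []
for block, layer_past in zip(blocks, past):
    hidden_states, present = block(hidden_states, layer_past)
    presents.append(present)

print(hidden_states)   # 3
print(presents)        # [(0, 0), (1, 1), (2, 2)] -- one cache entry per block, in order
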
@@ -607,8 +607,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
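
On the labelled path, ignore_index=-1 lets callers mask out positions (padding, or prompt tokens) by setting their label to -1; those positions contribute nothing to the loss. A standalone demonstration with toy logits (a vocabulary of 4 over three positions, values chosen arbitrarily):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(3, 4)                     # (positions, vocab)
labels = torch.tensor([2, -1, 0])              # middle position is masked out

loss_fct = CrossEntropyLoss(ignore_index=-1)
masked = loss_fct(logits, labels)              # averages over positions 0 and 2 only
unmasked = CrossEntropyLoss()(logits[[0, 2]], labels[[0, 2]])
print(torch.allclose(masked, unmasked))        # True
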
@@ -673,8 +673,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
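
For reference, a multiple-choice head of this style typically gathers the hidden state at each mc_token_ids position (the classification token of each answer choice) and scores it with a linear layer. The class below is a hedged sketch of that general pattern, not the head defined in the library file; the 4-D hidden_states layout is an assumption for the example.

import torch
import torch.nn as nn

class SketchMultipleChoiceHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, mc_token_ids):
        # hidden_states: (batch, num_choices, seq_len, hidden); mc_token_ids: (batch, num_choices)
        index = mc_token_ids.unsqueeze(-1).unsqueeze(-1)              # (batch, choices, 1, 1)
        index = index.expand(-1, -1, 1, hidden_states.size(-1))      # (batch, choices, 1, hidden)
        cls_states = hidden_states.gather(2, index).squeeze(2)       # (batch, choices, hidden)
        return self.linear(cls_states).squeeze(-1)                   # (batch, choices)

head = SketchMultipleChoiceHead(hidden_size=8)
logits = head(torch.randn(2, 4, 10, 8), torch.randint(0, 10, (2, 4)))
print(logits.shape)  # torch.Size([2, 4]) -- one score per answer choice
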