language update

thomwolf 2019-02-18 00:55:47 +01:00
parent 210d407245
commit 5ff0c60505

@@ -237,17 +237,17 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)
 
-    def forward(self, x, past=None):
+    def forward(self, x, layer_past=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        present = key, value
-        if past is not None:
-            past_key, past_value = past
+        if layer_past is not None:
+            past_key, past_value = layer_past[0], layer_past[1]
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
+        present = torch.stack((key, value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
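
The rename makes the per-layer cache explicit: each attention layer receives its own `layer_past` and returns a stacked `present` that a caller can feed back in on the next step. Below is a minimal sketch of just the cache mechanics, outside the model, assuming key/value shapes of `[batch, heads, seq_len, head_dim]`; all tensor names here are illustrative, not part of the commit.

import torch

batch, heads, head_dim = 1, 2, 4

# First pass: keys/values for a 3-token prompt, no cache yet.
key = torch.randn(batch, heads, 3, head_dim)
value = torch.randn(batch, heads, 3, head_dim)
present = torch.stack((key, value))            # [2, batch, heads, 3, head_dim]

# Second pass: one new token; the previous present is now layer_past.
layer_past = present
new_key = torch.randn(batch, heads, 1, head_dim)
new_value = torch.randn(batch, heads, 1, head_dim)

past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, new_key), dim=-2)       # sequence grows 3 -> 4
value = torch.cat((past_value, new_value), dim=-2)
present = torch.stack((key, value))                # cache for the next step

assert present.shape == (2, batch, heads, 4, head_dim)

Stacking key and value into one tensor is what lets `layer_past[0]` and `layer_past[1]` in the new code above recover them without unpacking a tuple.
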
@@ -277,8 +277,8 @@ class Block(nn.Module):
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
-    def forward(self, x, past=None):
-        a, present = self.attn(self.ln_1(x), past=past)
+    def forward(self, x, layer_past=None):
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
@@ -346,7 +346,7 @@ class GPT2PreTrainedModel(nn.Module):
         )
         self.config = config
 
-    def set_tied():
+    def set_tied(self):
         pass
 
     def init_weights(self, module):
@@ -526,12 +526,12 @@ class GPT2Model(GPT2PreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, pasts=None):
-        if pasts is None:
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
+        if past is None:
             past_length = 0
-            pasts = [None] * len(self.h)
+            past = [None] * len(self.h)
         else:
-            past_length = pasts[0][0].size(-2)
+            past_length = past[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
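
With a cache, a caller only feeds the new tokens, so their position ids must start at `past_length` instead of 0; that is what the `torch.arange(past_length, ...)` above provides. A standalone illustration of the offset (the token id is a hypothetical placeholder):

import torch

past_length = 3                       # 3 tokens already in the cache
input_ids = torch.tensor([[50256]])   # one new token this step

position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)                   # tensor([[3]]) -- continues after the cache
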
@@ -549,8 +549,8 @@ class GPT2Model(GPT2PreTrainedModel):
             token_type_embeds = 0
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
         presents = []
-        for block, past in zip(self.h, pasts):
-            hidden_states, present = block(hidden_states, past)
+        for block, layer_past in zip(self.h, past):
+            hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
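
On the first pass `past` is `[None] * len(self.h)`, so every block starts without a cache; on later passes it is the previous call's `presents`, and `zip` hands each block its own entry. A toy version of that handoff, with a stand-in block (illustrative names only):

import torch

def toy_block(hidden, layer_past):
    # stand-in for Block.forward: new hidden state plus a cache entry
    present = torch.stack((hidden, hidden))
    return hidden + 1, present

n_layer = 2
hidden = torch.zeros(1, 3)

past = [None] * n_layer               # first pass: no cache anywhere
presents = []
for layer_past in past:
    hidden, present = toy_block(hidden, layer_past)
    presents.append(present)

past = presents                       # these become the next call's past
assert len(past) == n_layer and all(p is not None for p in past)
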
@@ -607,8 +607,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
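
Taken together, the renamed `past` argument supports incremental decoding from user code. A hedged usage sketch against the API as it appears in this diff: the `pytorch_pretrained_bert` imports, the `'gpt2'` weights, and the assumption that `forward` returns `(lm_logits, presents)` when `lm_labels` is `None` come from the surrounding package, not from this commit.

import torch
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

input_ids = torch.tensor([tokenizer.encode('The past argument lets the model')])

past = None
with torch.no_grad():
    for _ in range(10):
        # After the first step, only the newest token is fed in;
        # the cache carries the rest of the context.
        lm_logits, past = model(input_ids, past=past)
        input_ids = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)

On the first iteration the whole prompt goes through; afterwards each call processes a single token, which is the point of caching `presents`.
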
@@ -673,8 +673,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         """
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
 
-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, pasts=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, pasts)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
+        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []