Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
TorchScript flag in config; Tied weights when not running TorchScript; tuple concatenation clean-up.
Parent: 4703148f0c
Commit: b43b130f35
@@ -46,6 +46,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop('num_labels', 2)
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+        self.torchscript = kwargs.pop('torchscript', False)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
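Note on the new flag above: `torchscript` is popped from the config kwargs with a default of `False`, so existing configs are unaffected. A minimal standalone sketch of that pattern (hypothetical `TinyConfig`, not the library class):

class TinyConfig(object):
    # Hypothetical stand-in for PretrainedConfig's kwargs handling.
    def __init__(self, **kwargs):
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)  # new flag, off by default

config = TinyConfig(torchscript=True)
assert config.torchscript is True
assert TinyConfig().torchscript is False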
@@ -428,23 +428,23 @@ class BertEncoder(nn.Module):
         all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states,)
+                all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]

             if self.output_attentions:
-                all_attentions += (layer_outputs[1],)
+                all_attentions = all_attentions + (layer_outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states,)

         outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)


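On the `+=` versus explicit concatenation changes above and in the hunks below: for tuples the two spellings are equivalent in Python, since tuples are immutable and `+=` already rebinds the name to a freshly built tuple; the explicit form simply makes that rebinding obvious. A quick self-contained check:

all_hidden_states = ()
before = id(all_hidden_states)
all_hidden_states += ('layer_0',)
assert id(all_hidden_states) != before   # += created a new tuple object

explicit = ()
explicit = explicit + ('layer_0',)
assert explicit == all_hidden_states     # same result either way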
@@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config, bert_model_embedding_weights):
         super(BertLMPredictionHead, self).__init__()
         self.transform = BertPredictionHeadTransform(config)
+        self.torchscript = config.torchscript

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+        else:
+            self.decoder.weight = bert_model_embedding_weights
+
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

     def forward(self, hidden_states):
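The branch above is the core of the "tied weights when not running TorchScript" part of this commit: outside of TorchScript the decoder shares the embedding's Parameter, while for export it gets an independent clone, since (per the GPT-2 comment further down) TorchScript export can't handle parameter sharing. A small self-contained illustration of the difference, not the library class itself:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)

tied = nn.Linear(4, 10, bias=False)
tied.weight = emb.weight                          # shared Parameter: one tensor, two modules

cloned = nn.Linear(4, 10, bias=False)
cloned.weight = nn.Parameter(emb.weight.clone())  # independent copy, safe to export

assert tied.weight is emb.weight
assert cloned.weight.data_ptr() != emb.weight.data_ptr()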
@@ -322,6 +322,7 @@ class GPT2LMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)
@@ -329,7 +330,10 @@ class GPT2LMHead(nn.Module):
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         # Export to TorchScript can't handle parameter sharing so we are cloning them.
-        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
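With the config flag set, the cloned decoder weight lets a head like this be traced. A hedged sketch of that workflow with a toy stand-in for GPT2LMHead (class name, sizes and call are illustrative, not the library API):

import torch
import torch.nn as nn

class TinyLMHead(nn.Module):
    # Toy stand-in mirroring the torchscript branch above.
    def __init__(self, embedding_weights, torchscript=False):
        super(TinyLMHead, self).__init__()
        vocab_size, n_embd = embedding_weights.shape
        self.decoder = nn.Linear(n_embd, vocab_size, bias=False)
        if torchscript:
            self.decoder.weight = nn.Parameter(embedding_weights.clone())
        else:
            self.decoder.weight = embedding_weights   # tied weights

    def forward(self, hidden_state):
        return self.decoder(hidden_state)

emb = nn.Embedding(100, 16)
head = TinyLMHead(emb.weight, torchscript=True)
traced = torch.jit.trace(head, torch.randn(2, 16))   # traces cleanly with cloned weights
print(traced(torch.randn(2, 16)).shape)              # torch.Size([2, 100])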
@@ -563,11 +567,11 @@ class GPT2Model(GPT2PreTrainedModel):
         all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states.view(*output_shape),)
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents += (present,)
+            presents = presents + (present,)

             if self.output_attentions:
                 all_attentions.append(outputs[2])
@@ -577,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states,)

         outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
             all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)


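A worked example for the attention reshape kept above (`input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]`), with toy sizes; the `-1` leaves the head dimension free so the view still works after head pruning:

import torch

batch, seq, heads = 2, 5, 4
input_shape = (batch, seq)                        # shape of the input ids
attn = torch.randn(batch, heads, seq, seq)        # one layer's attention weights
attention_output_shape = input_shape[:-1] + (-1,) + attn.shape[-2:]
print(attention_output_shape)                     # (2, -1, 5, 5)
print(attn.view(*attention_output_shape).shape)   # torch.Size([2, 4, 5, 5])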
@@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
@@ -583,22 +587,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states += (hidden_states.view(*output_shape),)
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions += (outputs[1],)
+                all_attentions = all_attentions + (outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states += (hidden_states.view(*output_shape),)
+            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

         outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs += (all_hidden_states,)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs += (all_attentions,)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, (all hidden states), (all attentions)


@@ -530,7 +530,7 @@ class XLNetRelativeAttention(nn.Module):

         outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs += (attn_prob,)
+            outputs = outputs + (attn_prob,)
         return outputs

 class XLNetFeedForward(nn.Module):
@@ -878,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems += (self.cache_mem(output_h, mems[i]),)
+            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

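For context on `self.cache_mem(output_h, mems[i])` above: the helper itself is not shown in this diff, so the following is only a hedged sketch of a Transformer-XL-style memory update (append the current layer output along the time axis, keep the last `mem_len` steps, detach from the graph); the name and behaviour here are assumptions, not the file's code:

import torch

def cache_mem_sketch(curr_out, prev_mem, mem_len=4):
    # Assumed behaviour: concatenate along time, truncate to mem_len, detach.
    new_mem = curr_out if prev_mem is None else torch.cat([prev_mem, curr_out], dim=0)
    return new_mem[-mem_len:].detach()

prev = torch.randn(4, 2, 8)    # (mem_len, batch, d_model)
curr = torch.randn(3, 2, 8)    # (seq_len, batch, d_model)
print(cache_mem_sketch(curr, prev).shape)   # torch.Size([4, 2, 8])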
@@ -902,10 +902,10 @@ class XLNetModel(XLNetPreTrainedModel):
                 hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
                 hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
-            outputs += (hidden_states,)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
             attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs += (attentions,)
+            outputs = outputs + (attentions,)

         return outputs  # outputs, new_mems, (hidden_states), (attentions)

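A toy shape check for the permutes kept above, assuming XLNet's internal time-major layouts ((seq, batch, d_model) for hidden states and (q_len, k_len, batch, heads) for attention probabilities); the permutes put batch first for the returned tuples:

import torch

seq, batch, d_model, heads = 5, 2, 8, 4
h = torch.randn(seq, batch, d_model)
attn = torch.randn(seq, seq, batch, heads)
print(h.permute(1, 0, 2).shape)         # torch.Size([2, 5, 8])    -> (batch, seq, d_model)
print(attn.permute(2, 3, 0, 1).shape)   # torch.Size([2, 4, 5, 5]) -> (batch, heads, q, k)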
@@ -975,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         super(XLNetLMHeadModel, self).__init__(config)
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.torchscript = config.torchscript

         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
@@ -987,7 +988,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        if self.torchscript:
+            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        else:
+            self.lm_loss.weight = self.transformer.word_embedding.weight

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
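The same clone-or-share choice applied to XLNet's `tie_weights`. A standalone toy version of the branching (hypothetical `ToyLMModel`, not the library class), plus a check that the non-TorchScript path really shares one Parameter:

import torch
import torch.nn as nn

class ToyLMModel(nn.Module):
    def __init__(self, vocab=50, d_model=8, torchscript=False):
        super(ToyLMModel, self).__init__()
        self.torchscript = torchscript
        self.word_embedding = nn.Embedding(vocab, d_model)
        self.lm_loss = nn.Linear(d_model, vocab, bias=True)
        self.tie_weights()

    def tie_weights(self):
        if self.torchscript:
            self.lm_loss.weight = nn.Parameter(self.word_embedding.weight.clone())
        else:
            self.lm_loss.weight = self.word_embedding.weight

shared = ToyLMModel(torchscript=False)
assert shared.lm_loss.weight is shared.word_embedding.weight

exported = ToyLMModel(torchscript=True)
assert exported.lm_loss.weight is not exported.word_embedding.weight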
@@ -41,6 +41,7 @@ def _create_and_check_torchscript_output_hidden_state(tester, model_classes, con

 def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    configs_no_init.torchscript = True
     for model_class in model_classes:
         model = model_class(config=configs_no_init)
         model.eval()