XLMWithLMHead fixed - standardize conversion
commit 969d3ae95e
parent 646711e1e2
@@ -57,7 +57,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):
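For context: a tf.keras model has no variables until it has been called once, so the dummy forward pass above is what actually builds the network; that is also why each converter now forwards its tf_inputs to the shared loader instead of only calling the model locally. A minimal sketch of the lazy-build behaviour, using a toy model rather than the library's classes:

import tensorflow as tf

# Toy stand-in for a model body: Keras creates variables lazily,
# on the first forward call, so the dummy batch below "builds" the model.
toy_model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10, output_dim=4),
    tf.keras.layers.Dense(2),
])

tf_inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
assert not toy_model.weights                # no variables exist yet
tfo = toy_model(tf_inputs, training=False)  # first call instantiates them
assert toy_model.weights                    # now they can be read or overwritten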
@@ -46,7 +46,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):
@@ -19,34 +19,34 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)

 import logging
-from pytorch_transformers import is_tf_available, is_torch_available
+import os

 logger = logging.getLogger(__name__)


-def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path):
+def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
     """ Load pytorch checkpoints in a TF 2.0 model
         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
             - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
     """
-    if not is_tf_available() or not is_torch_available():
+    try:
+        import tensorflow as tf
+        import torch
+    except ImportError as e:
         logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
-        raise ImportError
-
-    import torch
+        raise e

     pt_path = os.path.abspath(pytorch_checkpoint_path)
     logger.info("Loading PyTorch weights from {}".format(pt_path))

     pt_state_dict = torch.load(pt_path, map_location='cpu')

-    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)
+    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)


-def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
+def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
     """ Load pytorch state_dict in a TF 2.0 model.
         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
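The two docstring conventions are easier to see on a concrete name. A hedged sketch of just the name mapping, with an illustrative helper name that is not the library's API (the real loader does more, e.g. renaming kernel to weight and transposing dense kernels):

def tf_name_to_pt_name(tf_name):
    """Illustrative: apply the two scope-name conventions described above."""
    name = tf_name.split(':')[0]     # drop the ':0' suffix TF puts on variable names
    name = name.replace('_._', '/')  # '_._' opens a new level (e.g. an nn.ModuleList index)
    levels = [level.split('___')[-1] for level in name.split('/')]
    return '.'.join(level for level in levels if level)  # an empty $2 drops the level

print(tf_name_to_pt_name('transformer/pred_layer_._proj/kernel:0'))
# -> transformer.pred_layer.proj.kernel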
@@ -102,6 +102,7 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):

     K.batch_set_value(weight_value_tuples)

-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
+    if tf_inputs is not None:
+        tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

     logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
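The guard matters because tf_inputs is now an optional argument: when dummy inputs are provided, the extra forward pass after K.batch_set_value exercises the network with the freshly assigned weights. A toy sketch of that assign-then-run pattern, assuming nothing beyond stock Keras:

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

layer = tf.keras.layers.Dense(2)
layer.build((None, 4))                       # create kernel/bias variables up front

# batch_set_value copies many numpy arrays into TF variables in one call
K.batch_set_value([(layer.kernel, np.ones((4, 2), dtype=np.float32)),
                   (layer.bias, np.zeros((2,), dtype=np.float32))])

out = layer(tf.ones((1, 4)))                 # forward pass runs with the restored values
print(out.numpy())                           # [[4. 4.]]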
@@ -50,11 +50,9 @@ def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     attns_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
     langs_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
-    tf_inputs = tf.constant(inputs_list)
-    tf_attns = tf.constant(attns_list)
-    tf_langs = tf.constant(langs_list)
-    tfo = tf_model([tf_inputs, tf_attns, tf_langs], training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    tf_inputs = [tf.constant(inputs_list), tf.constant(attns_list), tf.constant(langs_list)]
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def create_sinusoidal_embeddings(n_pos, dim, out):
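XLM is the one converter whose dummy input cannot be a single tensor: the attention mask and language ids feed their own layers, which only get built if the dummy call exercises them. Bundling the three tensors into one list keeps the shared loader's single tf_inputs argument working, since Keras accepts a list as the positional input. A toy sketch of a list-input model, with made-up layer names:

import tensorflow as tf

class ToyMultiInput(tf.keras.Model):
    """Stand-in for a model whose call() unpacks a list of inputs, as XLM does."""
    def __init__(self):
        super(ToyMultiInput, self).__init__()
        self.tok_emb = tf.keras.layers.Embedding(10, 4)
        self.lang_emb = tf.keras.layers.Embedding(2, 4)

    def call(self, inputs, training=False):
        input_ids, attn_mask, lang_ids = inputs
        hidden = self.tok_emb(input_ids) + self.lang_emb(lang_ids)
        return hidden * tf.cast(attn_mask, tf.float32)[..., tf.newaxis]

tf_inputs = [tf.constant([[7, 6, 0, 0, 1]]),      # token ids
             tf.constant([[1, 1, 0, 0, 1]]),      # attention mask
             tf.constant([[1, 1, 0, 0, 1]])]      # language ids
tfo = ToyMultiInput()(tf_inputs, training=False)  # one call builds every sub-layer
print(tfo.shape)                                  # (1, 5, 4)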
@@ -614,7 +612,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     """
     def __init__(self, config, *inputs, **kwargs):
         super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
-        self.transformer = TFXLMMainLayer(config, name='transformer___')
+        self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')


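This rename is the "XLMWithLMHead fixed" part of the commit: under the '$1___$2' convention an empty $2 drops the level entirely, so the old scope 'transformer___' would vanish from the derived PyTorch path and no longer line up with the PyTorch model's transformer attribute. Reusing the illustrative tf_name_to_pt_name sketch from earlier:

print(tf_name_to_pt_name('transformer___/word_embeddings/weight:0'))
# -> word_embeddings.weight   (the 'transformer' level is dropped)
print(tf_name_to_pt_name('transformer/word_embeddings/weight:0'))
# -> transformer.word_embeddings.weight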
@@ -45,7 +45,7 @@ def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)  # build the network
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):
@@ -563,10 +563,10 @@ class XLMPredLayer(nn.Module):
         """
         outputs = ()
         if self.asm is False:
-            scores = self.proj(x).view(-1, self.n_words)
+            scores = self.proj(x)
             outputs = (scores,) + outputs
             if y is not None:
-                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+                loss = F.cross_entropy(scores.view(-1, self.n_words), y, reduction='elementwise_mean')
                 outputs = (loss,) + outputs
         else:
             scores = self.proj.log_prob(x)
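The reshape moves out of the projection and into the loss because callers expect scores in their natural (batch, seq_len, n_words) shape, while F.cross_entropy insists on 2-D logits. A small sketch of the shape contract, with made-up sizes:

import torch
import torch.nn.functional as F

batch, seq_len, n_words = 2, 5, 11              # made-up sizes
scores = torch.randn(batch, seq_len, n_words)   # stays 3-D for downstream callers

y = torch.randint(n_words, (batch * seq_len,))  # flattened gold token ids
# cross_entropy wants (N, C) logits against (N,) targets, hence the view at loss time
# ('elementwise_mean' in the diff is the pre-1.0 spelling of reduction='mean')
loss = F.cross_entropy(scores.view(-1, n_words), y, reduction='mean')
print(loss.item())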