mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-20 21:18:21 +06:00
clean up a little bit PT <=> TF conversion
This commit is contained in:
parent
bebaa14039
commit
f8fb4335c9
@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
tf_inputs = tf.constant(inputs_list)
|
tf_inputs = tf.constant(inputs_list)
|
||||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
tfo = tf_model(tf_inputs, training=False) # build the network
|
||||||
|
|
||||||
pt_model = pt_model_class.from_pretrained(None,
|
pt_model = pt_model_class(config)
|
||||||
config=config,
|
pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
|
||||||
state_dict=torch.load(pytorch_checkpoint_path,
|
strict-False)
|
||||||
map_location='cpu'))
|
pt_model.eval()
|
||||||
|
|
||||||
pt_inputs = torch.tensor(inputs_list)
|
pt_inputs = torch.tensor(inputs_list)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(pt_inputs)
|
pto = pt_model(pt_inputs)
|
||||||
|
@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
|
if pretrained_model_name_or_path is not None and (
|
||||||
|
"albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
|
||||||
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
||||||
"https://github.com/google-research/google-research/issues/119 for more information.")
|
"https://github.com/google-research/google-research/issues/119 for more information.")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user