Clean up the PT <=> TF conversion a little bit

thomwolf 2019-12-05 15:19:32 +01:00
parent bebaa14039
commit f8fb4335c9
2 changed files with 7 additions and 5 deletions

transformers/convert_pytorch_checkpoint_to_tf2.py

@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)  # build the network
-    pt_model = pt_model_class.from_pretrained(None,
-                                              config=config,
-                                              state_dict=torch.load(pytorch_checkpoint_path,
-                                                                    map_location='cpu'))
+    pt_model = pt_model_class(config)
+    pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
+                             strict=False)
+    pt_model.eval()
     pt_inputs = torch.tensor(inputs_list)
     with torch.no_grad():
         pto = pt_model(pt_inputs)
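
In the old path, from_pretrained(None, config=config, state_dict=...) both constructed the model and loaded the weights; the new path separates the two steps and passes strict=False, so a checkpoint whose keys do not line up exactly with the freshly built model (missing heads, renamed parameters) no longer aborts the conversion. A minimal standalone sketch of the new pattern, assuming BertConfig/BertModel as stand-ins for config and pt_model_class, and a hypothetical checkpoint path:

import torch
from transformers import BertConfig, BertModel

config = BertConfig()          # stand-in for the converted model's config
pt_model = BertModel(config)   # direct construction instead of from_pretrained(None, ...)

# strict=False skips mismatched keys instead of raising a RuntimeError; the
# returned named tuple records whatever was skipped so it can be inspected.
state_dict = torch.load('./pytorch_model.bin', map_location='cpu')  # hypothetical path
result = pt_model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)

pt_model.eval()  # from_pretrained() did this implicitly; here it must be explicit
with torch.no_grad():
    pto = pt_model(torch.tensor([[7, 6, 0, 0, 1]]))

The explicit pt_model.eval() is the easy-to-miss part of the hunk: from_pretrained() returns the model in eval mode, so dropping the call would leave dropout active during the PT/TF output comparison.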

transformers/modeling_utils.py

@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
             model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
         """
-        if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
+        if pretrained_model_name_or_path is not None and (
+                "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
             logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
                            "https://github.com/google-research/google-research/issues/119 for more information.")