fix loading bug and check full conversion of model

This commit is contained in:
thomwolf 2019-01-16 12:16:20 +01:00
parent fea15cc9f5
commit 009101de12
2 changed files with 8 additions and 6 deletions

View File

@ -180,16 +180,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
del tf_weights[name]
tf_weights.pop(name, None)
tf_weights.pop(name + '/Adam', None)
tf_weights.pop(name + '/Adam_1', None)
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())

View File

@ -818,7 +818,7 @@ class TransfoXLPreTrainedModel(nn.Module):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
load(model, prefix='')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))