From 009101de12d5cb26e8cc0f1f6af9216b11c6a50f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 16 Jan 2019 12:16:20 +0100 Subject: [PATCH] fix loading bug and check full conversion of model --- .../convert_transfo_xl_checkpoint_to_pytorch.py | 12 +++++++----- pytorch_pretrained_bert/modeling_transfo_xl.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py index 5b8ba99678a..223bbec963b 100755 --- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -180,16 +180,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, raise print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) - del tf_weights[name] + tf_weights.pop(name, None) + tf_weights.pop(name + '/Adam', None) + tf_weights.pop(name + '/Adam_1', None) print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME - print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) + pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) + pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) + print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) torch.save(model.state_dict(), pytorch_weights_dump_path) - print("Save configuration file to {}".format(pytorch_config_dump_path)) + print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: f.write(config.to_json_string()) diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index de0430e9644..452f2e03ea1 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -818,7 +818,7 @@ class TransfoXLPreTrainedModel(nn.Module): for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') - # load(model.transformer if hasattr(model, 'transformer') else model, prefix='') + load(model, prefix='') if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys))