mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
fix loading bug and check full conversion of model
This commit is contained in:
parent
fea15cc9f5
commit
009101de12
@ -180,16 +180,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
del tf_weights[name]
|
||||
tf_weights.pop(name, None)
|
||||
tf_weights.pop(name + '/Adam', None)
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
||||
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
@ -818,7 +818,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||
load(model, prefix='')
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys))
|
||||
|
Loading…
Reference in New Issue
Block a user