mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
fix loading bug and check full conversion of model
This commit is contained in:
parent
fea15cc9f5
commit
009101de12
@ -180,16 +180,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
raise
|
raise
|
||||||
print("Initialize PyTorch weight {}".format(name))
|
print("Initialize PyTorch weight {}".format(name))
|
||||||
pointer.data = torch.from_numpy(array)
|
pointer.data = torch.from_numpy(array)
|
||||||
del tf_weights[name]
|
tf_weights.pop(name, None)
|
||||||
|
tf_weights.pop(name + '/Adam', None)
|
||||||
|
tf_weights.pop(name + '/Adam_1', None)
|
||||||
|
|
||||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
||||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
||||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
||||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||||
f.write(config.to_json_string())
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
|
@ -818,7 +818,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
if child is not None:
|
if child is not None:
|
||||||
load(child, prefix + name + '.')
|
load(child, prefix + name + '.')
|
||||||
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
load(model, prefix='')
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||||
model.__class__.__name__, missing_keys))
|
model.__class__.__name__, missing_keys))
|
||||||
|
Loading…
Reference in New Issue
Block a user