update model conversion

This commit is contained in:
thomwolf 2019-01-16 11:54:54 +01:00
parent a28dfc8659
commit fea15cc9f5
2 changed files with 15 additions and 23 deletions

View File

@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config):
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
})
# Softmax cutoffs
# Adaptive Softmax
tf_to_pt_map.update({
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
model.crit.out_layers,
model.crit.out_projs,
@ -169,14 +172,17 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
continue
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
else:
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
del tf_weights[name]
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME

View File

@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module):
if state_dict is None:
state_dict = torch.load(resolved_archive_file)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []