mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
update model conversion
This commit is contained in:
parent
a28dfc8659
commit
fea15cc9f5
@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config):
|
||||
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
|
||||
})
|
||||
|
||||
# Softmax cutoffs
|
||||
# Adaptive Softmax
|
||||
tf_to_pt_map.update({
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
|
||||
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
|
||||
model.crit.out_layers,
|
||||
model.crit.out_projs,
|
||||
@ -169,14 +172,17 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
continue
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
del tf_weights[name]
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
|
||||
# Save the PyTorch model weights
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
|
@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
if state_dict is None:
|
||||
state_dict = torch.load(resolved_archive_file)
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
Loading…
Reference in New Issue
Block a user