From fea15cc9f5939bbd1cb162921ae273da9de49c14 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 16 Jan 2019 11:54:54 +0100 Subject: [PATCH] update model conversion --- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 24 ++++++++++++------- .../modeling_transfo_xl.py | 14 ----------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py index b2f8432d3a5..5b8ba99678a 100755 --- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -68,7 +68,10 @@ def build_tf_to_pytorch_map(model, config): layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, }) - # Softmax cutoffs + # Adaptive Softmax + tf_to_pt_map.update({ + "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight, + "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias}) for i, (out_l, proj_l, tie_proj) in enumerate(zip( model.crit.out_layers, model.crit.out_projs, @@ -169,14 +172,17 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, raise print("Initialize PyTorch weight {} for layer {}".format(name, i)) p_i.data = torch.from_numpy(arr_i) - continue - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) + else: + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + del tf_weights[name] + + print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) # Save pytorch-model pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 0e1f3f82401..de0430e9644 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -802,20 +802,6 @@ class TransfoXLPreTrainedModel(nn.Module): if state_dict is None: state_dict = torch.load(resolved_archive_file) - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - missing_keys = [] unexpected_keys = [] error_msgs = []