diff --git a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
index 79b5f41adcf..120624bc1b4 100755
--- a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
+++ b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
@@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
                 l = re.split(r'_(\d+)', m_name)
             else:
                 l = [m_name]
-            if l[0] == 'kernel':
+            if l[0] == 'kernel' or l[0] == 'gamma':
                 pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias':
+            elif l[0] == 'output_bias' or l[0] == 'beta':
                 pointer = getattr(pointer, 'bias')
             elif l[0] == 'output_weights':
                 pointer = getattr(pointer, 'weight')
diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index 0699f671999..c6940c74eb2 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -516,9 +516,9 @@ class PreTrainedBertModel(nn.Module):
         for key in state_dict.keys():
             new_key = None
             if 'gamma' in key:
-                new_key = key.replace('gamma','weight')
+                new_key = key.replace('gamma', 'weight')
             if 'beta' in key:
-                new_key = key.replace('beta','bias')
+                new_key = key.replace('beta', 'bias')
             if new_key:
                 old_keys.append(key)
                 new_keys.append(new_key)
diff --git a/tests/optimization_test.py b/tests/optimization_test.py
index 848b9d1cf5c..18463735915 100644
--- a/tests/optimization_test.py
+++ b/tests/optimization_test.py
@@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase):
         criterion = torch.nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
         optimizer = BertAdam(params=[w], lr=2e-1,
-                             weight_decay=0.0,
+                             weight_decay_rate=0.0,
                              max_grad_norm=-1)
         for _ in range(100):
             loss = criterion(w, target)