mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
clean up apex integration
This commit is contained in:
parent
4946c2c500
commit
52c53f39d0
@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
l = re.split(r'_(\d+)', m_name)
|
l = re.split(r'_(\d+)', m_name)
|
||||||
else:
|
else:
|
||||||
l = [m_name]
|
l = [m_name]
|
||||||
if l[0] == 'kernel':
|
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, 'weight')
|
||||||
elif l[0] == 'output_bias':
|
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||||
pointer = getattr(pointer, 'bias')
|
pointer = getattr(pointer, 'bias')
|
||||||
elif l[0] == 'output_weights':
|
elif l[0] == 'output_weights':
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, 'weight')
|
||||||
|
@ -516,9 +516,9 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
new_key = None
|
new_key = None
|
||||||
if 'gamma' in key:
|
if 'gamma' in key:
|
||||||
new_key = key.replace('gamma','weight')
|
new_key = key.replace('gamma', 'weight')
|
||||||
if 'beta' in key:
|
if 'beta' in key:
|
||||||
new_key = key.replace('beta','bias')
|
new_key = key.replace('beta', 'bias')
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
new_keys.append(new_key)
|
new_keys.append(new_key)
|
||||||
|
@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase):
|
|||||||
criterion = torch.nn.MSELoss()
|
criterion = torch.nn.MSELoss()
|
||||||
# No warmup, constant schedule, no gradient clipping
|
# No warmup, constant schedule, no gradient clipping
|
||||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||||
weight_decay=0.0,
|
weight_decay_rate=0.0,
|
||||||
max_grad_norm=-1)
|
max_grad_norm=-1)
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
loss = criterion(w, target)
|
loss = criterion(w, target)
|
||||||
|
Loading…
Reference in New Issue
Block a user