mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #1778 from eukaryote31/patch-2
from_pretrained: convert DialoGPT format
This commit is contained in:
commit
d49c43ff78
@ -118,6 +118,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
# DialoGPT format
|
||||
if key == 'lm_head.decoder.weight':
|
||||
new_key = 'lm_head.weight'
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
|
@ -427,6 +427,8 @@ class PreTrainedModel(nn.Module):
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if key == 'lm_head.decoder.weight':
|
||||
new_key = 'lm_head.weight'
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
|
Loading…
Reference in New Issue
Block a user