mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
probably ok weights conversion script
This commit is contained in:
parent
ab0e8932a8
commit
960ef4df3b
@ -9,6 +9,7 @@ import re
|
|||||||
import argparse
|
import argparse
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from modeling_pytorch import BertConfig, BertModel
|
from modeling_pytorch import BertConfig, BertModel
|
||||||
|
|
||||||
@ -55,7 +56,11 @@ def convert():
|
|||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name[5:] # skip "bert/"
|
name = name[5:] # skip "bert/"
|
||||||
|
print("Loading {}".format(name))
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
|
if name[0] in ['redictions', 'eq_relationship']:
|
||||||
|
print("Skipping")
|
||||||
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||||
@ -71,8 +76,8 @@ def convert():
|
|||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
if m_name[-11:] == '_embeddings':
|
if m_name[-11:] == '_embeddings':
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, 'weight')
|
||||||
# elif m_name == 'kernel':
|
elif m_name == 'kernel':
|
||||||
# pointer = getattr(pointer, 'weight')
|
array = np.transpose(array)
|
||||||
try:
|
try:
|
||||||
assert pointer.shape == array.shape
|
assert pointer.shape == array.shape
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
|
Loading…
Reference in New Issue
Block a user