mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
fixed the fix. tf session madness.
This commit is contained in:
parent
edfd965ac8
commit
09ecf225e9
@@ -62,34 +62,34 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
|||||||
if not os.path.isdir(ckpt_dir):
|
if not os.path.isdir(ckpt_dir):
|
||||||
os.makedirs(ckpt_dir)
|
os.makedirs(ckpt_dir)
|
||||||
|
|
||||||
session = tf.Session()
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
tf_vars = []
|
|
||||||
|
|
||||||
def to_tf_var_name(name:str):
|
def to_tf_var_name(name:str):
|
||||||
for patt, repl in iter(var_map):
|
for patt, repl in iter(var_map):
|
||||||
name = name.replace(patt, repl)
|
name = name.replace(patt, repl)
|
||||||
return 'bert/{}'.format(name)
|
return 'bert/{}'.format(name)
|
||||||
|
|
||||||
def assign_tf_var(tensor:np.ndarray, name:str):
|
def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session):
|
||||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name)
|
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||||
session.run(tf.variables_initializer([tf_var]))
|
session.run(tf.variables_initializer([tf_var]))
|
||||||
tf.keras.backend.set_value(tf_var, tensor)
|
|
||||||
session.run(tf_var)
|
session.run(tf_var)
|
||||||
return tf_var
|
return tf_var
|
||||||
|
|
||||||
for var_name in state_dict:
|
tf.reset_default_graph()
|
||||||
tf_name = to_tf_var_name(var_name)
|
with tf.Session() as session:
|
||||||
torch_tensor = state_dict[var_name].numpy()
|
for var_name in state_dict:
|
||||||
if any([x in var_name for x in tensors_to_transopse]):
|
tf_name = to_tf_var_name(var_name)
|
||||||
torch_tensor = torch_tensor.T
|
torch_tensor = state_dict[var_name].numpy()
|
||||||
tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
|
if any([x in var_name for x in tensors_to_transopse]):
|
||||||
tf_vars.append(tf_tensor)
|
torch_tensor = torch_tensor.T
|
||||||
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
|
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||||
|
tf.keras.backend.set_value(tf_var, torch_tensor)
|
||||||
|
tf_weight = session.run(tf_var)
|
||||||
|
print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
|
||||||
|
|
||||||
saver = tf.train.Saver(tf_vars)
|
saver = tf.train.Saver(tf.trainable_variables())
|
||||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||||
|
|
||||||
|
|
||||||
def main(raw_args=None):
|
def main(raw_args=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user