mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
spelling mistake
This commit is contained in:
parent
09ecf225e9
commit
c90119e543
@ -41,7 +41,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
||||
N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transopse = (
|
||||
tensors_to_transpose = (
|
||||
"dense.weight",
|
||||
"attention.self.query",
|
||||
"attention.self.key",
|
||||
@ -81,7 +81,7 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any([x in var_name for x in tensors_to_transopse]):
|
||||
if any([x in var_name for x in tensors_to_transpose]):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf.keras.backend.set_value(tf_var, torch_tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user