diff --git a/convert_tf_checkpoint.py b/convert_tf_checkpoint.py index b8fb49bf1f1..522e9724d2d 100644 --- a/convert_tf_checkpoint.py +++ b/convert_tf_checkpoint.py @@ -9,6 +9,7 @@ import re import argparse import tensorflow as tf import torch +import numpy as np from modeling_pytorch import BertConfig, BertModel @@ -55,7 +56,11 @@ def convert(): for name, array in zip(names, arrays): name = name[5:] # skip "bert/" + print("Loading {}".format(name)) name = name.split('/') + if name[0] in ['redictions', 'eq_relationship']: + print("Skipping") + continue pointer = model for m_name in name: if re.fullmatch(r'[A-Za-z]+_\d+', m_name): @@ -71,8 +76,8 @@ def convert(): pointer = pointer[num] if m_name[-11:] == '_embeddings': pointer = getattr(pointer, 'weight') - # elif m_name == 'kernel': - # pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) try: assert pointer.shape == array.shape except AssertionError as e: