diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py
index d1776e9c14d..d20eafe2e91 100644
--- a/transformers/convert_pytorch_checkpoint_to_tf2.py
+++ b/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         tf_inputs = tf.constant(inputs_list)
         tfo = tf_model(tf_inputs, training=False)  # build the network
 
-        pt_model = pt_model_class.from_pretrained(None,
-                                                  config=config,
-                                                  state_dict=torch.load(pytorch_checkpoint_path,
-                                                                        map_location='cpu'))
+        pt_model = pt_model_class(config)
+        pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
+                                 strict=False)
+        pt_model.eval()
+
         pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
             pto = pt_model(pt_inputs)
diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index 398172a88c5..3ac568771e3 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
             model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
 
         """
-        if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
+        if pretrained_model_name_or_path is not None and (
+                "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
             logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
                            "https://github.com/google-research/google-research/issues/119 for more information.")
 
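
Note on the first hunk: the old path routed a local checkpoint through `from_pretrained(None, ..., state_dict=...)`; the new path instantiates the model class directly, loads weights with `load_state_dict(..., strict=False)`, and switches to eval mode so the PyTorch/TF output comparison is deterministic. A minimal standalone sketch of the same pattern, assuming a BERT checkpoint; the config path, checkpoint path, and token IDs below are placeholders, not values from this patch:

```python
import torch
from transformers import BertConfig, BertForPreTraining

# Placeholder paths -- substitute your own config and checkpoint.
config = BertConfig.from_json_file('bert_config.json')
pt_model = BertForPreTraining(config)  # randomly initialized weights

# strict=False tolerates checkpoint keys the model class does not define
# (and vice versa); load_state_dict reports both sets back.
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
missing, unexpected = pt_model.load_state_dict(state_dict, strict=False)

pt_model.eval()  # disable dropout so repeated forward passes match
with torch.no_grad():
    pto = pt_model(torch.tensor([[7, 6, 0, 0, 1]]))  # placeholder token IDs
```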
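Note on the second hunk: `from_pretrained` accepts `None` for `pretrained_model_name_or_path` (the converter above used exactly that call before this patch), and `"albert" in None` raises `TypeError: argument of type 'NoneType' is not iterable`, so the ALBERT v2 warning check needs a None guard. A minimal sketch of the guarded check in isolation; `_warn_if_albert_v2` is a hypothetical helper, not part of the library:

```python
import logging

logger = logging.getLogger(__name__)

def _warn_if_albert_v2(name):
    # Hypothetical standalone version of the guarded check in modeling_utils.
    if name is not None and "albert" in name and "v2" in name:
        logger.warning("ALBERT v2 models currently have an upstream "
                       "reproducibility issue; see google-research#119.")

_warn_if_albert_v2(None)              # silently skips instead of raising TypeError
_warn_if_albert_v2("albert-base-v2")  # emits the warning
```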