diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 6c48f3eed2a..95c29693d85 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -265,8 +265,10 @@ class TFPreTrainedModel(tf.keras.Model): pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + archive_file = pretrained_model_name_or_path + ".index" else: - raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path)) + archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 398172a88c5..eec9034fd7c 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -364,9 +364,12 @@ class PreTrainedModel(nn.Module): pretrained_model_name_or_path)) elif os.path.isfile(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path - else: - assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path) + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( + pretrained_model_name_or_path + ".index") archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: