allowing from_pretrained to load from url directly

This commit is contained in:
thomwolf 2019-12-11 17:19:18 +01:00 committed by Morgan Funtowicz
parent c28273793e
commit 6709739a05
2 changed files with 8 additions and 3 deletions

View File

@ -265,8 +265,10 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
else:
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:

View File

@ -364,9 +364,12 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index")
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try: