mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
allowing from_pretrained to load from url directly
This commit is contained in:
parent
c28273793e
commit
6709739a05
@ -265,8 +265,10 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
|
||||
archive_file = pretrained_model_name_or_path
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
@ -364,9 +364,12 @@ class PreTrainedModel(nn.Module):
|
||||
pretrained_model_name_or_path))
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||
pretrained_model_name_or_path + ".index")
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
archive_file = pretrained_model_name_or_path
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user