reraise EnvironmentError in from_pretrained functions of Model and Tokenizer

This commit is contained in:
Abhishek Rao 2019-08-22 15:25:40 -07:00
parent 14eef67eb2
commit c603d099aa
2 changed files with 4 additions and 4 deletions

View File

@ -473,7 +473,7 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
@ -486,7 +486,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:

View File

@ -293,7 +293,7 @@ class PreTrainedTokenizer(object):
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
except EnvironmentError as e:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")
else:
@ -303,7 +303,7 @@ class PreTrainedTokenizer(object):
"at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, str(vocab_files.keys())))
return None
raise e
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]: