Find module name in an OS-agnostic fashion (#24526)

* Find module name in an OS-agnostic fashion

* address review comment
This commit is contained in:
Sylvain Gugger 2023-06-27 13:21:19 -04:00 committed by GitHub
parent 7d150d68ff
commit 38db04ece0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -234,7 +234,7 @@ def get_cached_module_file(
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
submodule = os.path.basename(pretrained_model_name_or_path)
else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
cached_module = try_to_load_from_cache(
@ -271,7 +271,7 @@ def get_cached_module_file(
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
if submodule == os.path.basename(pretrained_model_name_or_path):
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
# has changed since last copy.
if not (submodule_path / module_file).exists() or not filecmp.cmp(