From 3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 28 Mar 2023 00:30:16 +0530 Subject: [PATCH] [neptune] fix checkpoint bug with relative out_dir (#22102) * [neptune] fix checkpoint bug with relative out_dir * update imports * reformat with black * check neptune without imports * fix typing-related issue * run black on code * use os.path.sep instead of raw \ * simplify imports and remove type annotation * make ruff happy * apply review suggestions --------- Co-authored-by: Aleksander Wojnarowicz --- src/transformers/integrations.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 14857f83083..52e4d92148e 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -31,6 +31,7 @@ import numpy as np from . import __version__ as version from .utils import flatten_dict, is_datasets_available, is_torch_available, logging +from .utils.versions import importlib_metadata logger = logging.get_logger(__name__) @@ -53,9 +54,19 @@ if _has_comet: except (ImportError, ValueError): _has_comet = False -_has_neptune = importlib.util.find_spec("neptune") is not None +_has_neptune = ( + importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None +) if TYPE_CHECKING and _has_neptune: - from neptune.new.metadata_containers.run import Run + try: + _neptune_version = importlib_metadata.version("neptune") + logger.info(f"Neptune version {_neptune_version} available.") + except importlib_metadata.PackageNotFoundError: + try: + _neptune_version = importlib_metadata.version("neptune-client") + logger.info(f"Neptune-client version {_neptune_version} available.") + except importlib_metadata.PackageNotFoundError: + _has_neptune = False from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 @@ -1155,7 +1166,7 @@ class NeptuneCallback(TrainerCallback): project: Optional[str] = None, name: Optional[str] = None, base_namespace: str = "finetuning", - run: Optional["Run"] = None, + run=None, log_parameters: bool = True, log_checkpoints: Optional[str] = None, **neptune_run_kwargs, @@ -1163,15 +1174,15 @@ class NeptuneCallback(TrainerCallback): if not is_neptune_available(): raise ValueError( "NeptuneCallback requires the Neptune client library to be installed. " - "To install the library, run `pip install neptune-client`." + "To install the library, run `pip install neptune`." ) - from neptune.new.metadata_containers.run import Run - try: - from neptune.new.integrations.utils import verify_type + from neptune import Run + from neptune.internal.utils import verify_type except ImportError: from neptune.new.internal.utils import verify_type + from neptune.new.metadata_containers.run import Run verify_type("api_token", api_token, (str, type(None))) verify_type("project", project, (str, type(None))) @@ -1288,7 +1299,10 @@ class NeptuneCallback(TrainerCallback): if self._volatile_checkpoints_dir is not None: consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint) try: - shutil.copytree(relative_path, os.path.join(consistent_checkpoint_path, relative_path)) + # Remove leading ../ from a relative path. + cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep) + copy_path = os.path.join(consistent_checkpoint_path, cpkt_path) + shutil.copytree(relative_path, copy_path) target_path = consistent_checkpoint_path except IOError as e: logger.warning(