[neptune] fix checkpoint bug with relative out_dir (#22102)

* [neptune] fix checkpoint bug with relative out_dir

* update imports

* reformat with black

* check neptune without imports

* fix typing-related issue

* run black on code

* use os.path.sep instead of raw \

* simplify imports and remove type annotation

* make ruff happy

* apply review suggestions

---------

Co-authored-by: Aleksander Wojnarowicz <alwojnarowicz@gmail.com>
This commit is contained in:
Kshiteej K 2023-03-28 00:30:16 +05:30 committed by GitHub
parent 19ade2426a
commit 3ec7a47664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -31,6 +31,7 @@ import numpy as np
from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
from .utils.versions import importlib_metadata
logger = logging.get_logger(__name__)
@ -53,9 +54,19 @@ if _has_comet:
except (ImportError, ValueError):
_has_comet = False
_has_neptune = importlib.util.find_spec("neptune") is not None
_has_neptune = (
importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
if TYPE_CHECKING and _has_neptune:
from neptune.new.metadata_containers.run import Run
try:
_neptune_version = importlib_metadata.version("neptune")
logger.info(f"Neptune version {_neptune_version} available.")
except importlib_metadata.PackageNotFoundError:
try:
_neptune_version = importlib_metadata.version("neptune-client")
logger.info(f"Neptune-client version {_neptune_version} available.")
except importlib_metadata.PackageNotFoundError:
_has_neptune = False
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
@ -1155,7 +1166,7 @@ class NeptuneCallback(TrainerCallback):
project: Optional[str] = None,
name: Optional[str] = None,
base_namespace: str = "finetuning",
run: Optional["Run"] = None,
run=None,
log_parameters: bool = True,
log_checkpoints: Optional[str] = None,
**neptune_run_kwargs,
@ -1163,15 +1174,15 @@ class NeptuneCallback(TrainerCallback):
if not is_neptune_available():
raise ValueError(
"NeptuneCallback requires the Neptune client library to be installed. "
"To install the library, run `pip install neptune-client`."
"To install the library, run `pip install neptune`."
)
from neptune.new.metadata_containers.run import Run
try:
from neptune.new.integrations.utils import verify_type
from neptune import Run
from neptune.internal.utils import verify_type
except ImportError:
from neptune.new.internal.utils import verify_type
from neptune.new.metadata_containers.run import Run
verify_type("api_token", api_token, (str, type(None)))
verify_type("project", project, (str, type(None)))
@ -1288,7 +1299,10 @@ class NeptuneCallback(TrainerCallback):
if self._volatile_checkpoints_dir is not None:
consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
try:
shutil.copytree(relative_path, os.path.join(consistent_checkpoint_path, relative_path))
# Remove every ".." occurrence from the relative path, then strip any leading path separator.
cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
shutil.copytree(relative_path, copy_path)
target_path = consistent_checkpoint_path
except IOError as e:
logger.warning(