mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[neptune] fix checkpoint bug with relative out_dir (#22102)
* [neptune] fix checkpoint bug with relative out_dir * update imports * reformat with black * check neptune without imports * fix typing-related issue * run black on code * use os.path.sep instead of raw \ * simplify imports and remove type annotation * make ruff happy * apply review suggestions --------- Co-authored-by: Aleksander Wojnarowicz <alwojnarowicz@gmail.com>
This commit is contained in:
parent
19ade2426a
commit
3ec7a47664
@ -31,6 +31,7 @@ import numpy as np
|
||||
|
||||
from . import __version__ as version
|
||||
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
|
||||
from .utils.versions import importlib_metadata
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -53,9 +54,19 @@ if _has_comet:
|
||||
except (ImportError, ValueError):
|
||||
_has_comet = False
|
||||
|
||||
_has_neptune = importlib.util.find_spec("neptune") is not None
|
||||
_has_neptune = (
|
||||
importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
|
||||
)
|
||||
if TYPE_CHECKING and _has_neptune:
|
||||
from neptune.new.metadata_containers.run import Run
|
||||
try:
|
||||
_neptune_version = importlib_metadata.version("neptune")
|
||||
logger.info(f"Neptune version {_neptune_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
try:
|
||||
_neptune_version = importlib_metadata.version("neptune-client")
|
||||
logger.info(f"Neptune-client version {_neptune_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_has_neptune = False
|
||||
|
||||
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||
@ -1155,7 +1166,7 @@ class NeptuneCallback(TrainerCallback):
|
||||
project: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
base_namespace: str = "finetuning",
|
||||
run: Optional["Run"] = None,
|
||||
run=None,
|
||||
log_parameters: bool = True,
|
||||
log_checkpoints: Optional[str] = None,
|
||||
**neptune_run_kwargs,
|
||||
@ -1163,15 +1174,15 @@ class NeptuneCallback(TrainerCallback):
|
||||
if not is_neptune_available():
|
||||
raise ValueError(
|
||||
"NeptuneCallback requires the Neptune client library to be installed. "
|
||||
"To install the library, run `pip install neptune-client`."
|
||||
"To install the library, run `pip install neptune`."
|
||||
)
|
||||
|
||||
from neptune.new.metadata_containers.run import Run
|
||||
|
||||
try:
|
||||
from neptune.new.integrations.utils import verify_type
|
||||
from neptune import Run
|
||||
from neptune.internal.utils import verify_type
|
||||
except ImportError:
|
||||
from neptune.new.internal.utils import verify_type
|
||||
from neptune.new.metadata_containers.run import Run
|
||||
|
||||
verify_type("api_token", api_token, (str, type(None)))
|
||||
verify_type("project", project, (str, type(None)))
|
||||
@ -1288,7 +1299,10 @@ class NeptuneCallback(TrainerCallback):
|
||||
if self._volatile_checkpoints_dir is not None:
|
||||
consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
|
||||
try:
|
||||
shutil.copytree(relative_path, os.path.join(consistent_checkpoint_path, relative_path))
|
||||
# Remove leading ../ from a relative path.
|
||||
cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
|
||||
copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
|
||||
shutil.copytree(relative_path, copy_path)
|
||||
target_path = consistent_checkpoint_path
|
||||
except IOError as e:
|
||||
logger.warning(
|
||||
|
Loading…
Reference in New Issue
Block a user