mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Ensure that integrations are imported before transformers or ml libs (#7330)
* Ensure that intergrations are imported before transformers or ml libs * Black reformatter wanted a newline * isort requests * black requests * flake8 requests
This commit is contained in:
parent
3323146e90
commit
8c697d58ef
@ -17,6 +17,16 @@ else:
|
||||
absl.logging.set_stderrthreshold("info")
|
||||
absl.logging._warn_preinit_stderr = False
|
||||
|
||||
# Integrations: this needs to come before other ml imports
|
||||
# in order to allow any 3rd-party code to initialize properly
|
||||
from .integrations import ( # isort:skip
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
|
||||
# Configurations
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
||||
@ -96,15 +106,6 @@ from .file_utils import (
|
||||
)
|
||||
from .hf_argparser import HfArgumentParser
|
||||
|
||||
# Integrations
|
||||
from .integrations import (
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
|
||||
# Model Cards
|
||||
from .modelcard import ModelCard
|
||||
|
||||
|
@ -1,13 +1,7 @@
|
||||
# Integrations with other Python libraries
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
@ -16,7 +10,6 @@ try:
|
||||
except (ImportError):
|
||||
_has_comet = False
|
||||
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
@ -29,18 +22,6 @@ try:
|
||||
except (ImportError, AttributeError):
|
||||
_has_wandb = False
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter # noqa: F401
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
try:
|
||||
import optuna # noqa: F401
|
||||
|
||||
@ -56,6 +37,29 @@ except (ImportError):
|
||||
_has_ray = False
|
||||
|
||||
|
||||
# No ML framework or transformer imports above this point
|
||||
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
|
||||
from .utils import logging # isort:skip
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter # noqa: F401
|
||||
|
||||
_has_tensorboard = True
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
# Integration functions:
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
@ -135,7 +139,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
n_jobs = int(kwargs.pop("n_jobs", 1))
|
||||
num_gpus_per_trial = trainer.args.n_gpu
|
||||
if num_gpus_per_trial / n_jobs >= 1:
|
||||
num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
|
||||
num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
|
||||
kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
|
||||
|
||||
if "reporter" not in kwargs:
|
||||
|
Loading…
Reference in New Issue
Block a user