Ensure that integrations are imported before transformers or ml libs (#7330)

* Ensure that integrations are imported before transformers or ml libs

* Black reformatter wanted a newline

* Address isort requests

* Address black requests

* Address flake8 requests
Doug Blank 2020-09-23 13:23:45 -04:00 committed by GitHub
parent 3323146e90
commit 8c697d58ef
2 changed files with 35 additions and 30 deletions
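
Context for the change: integration libraries such as comet_ml and wandb attach their auto-logging hooks when they are imported, so they need to load before torch or tensorflow do. A minimal, self-contained sketch of the guarded-import pattern the diff below relies on (names mirror the diff):

try:
    import comet_ml  # noqa: F401  (must load before any ML framework so its hooks attach)
    _has_comet = True
except ImportError:
    _has_comet = False


def is_comet_available():
    # Cheap runtime probe: no re-import, just the flag recorded at module load.
    return _has_comet


print(is_comet_available())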

src/transformers/__init__.py

@@ -17,6 +17,16 @@ else:
     absl.logging.set_stderrthreshold("info")
     absl.logging._warn_preinit_stderr = False
 
+# Integrations: this needs to come before other ml imports
+# in order to allow any 3rd-party code to initialize properly
+from .integrations import (  # isort:skip
+    is_comet_available,
+    is_optuna_available,
+    is_ray_available,
+    is_tensorboard_available,
+    is_wandb_available,
+)
+
 # Configurations
 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
@@ -96,15 +106,6 @@ from .file_utils import (
 )
 from .hf_argparser import HfArgumentParser
 
-# Integrations
-from .integrations import (
-    is_comet_available,
-    is_optuna_available,
-    is_ray_available,
-    is_tensorboard_available,
-    is_wandb_available,
-)
-
 # Model Cards
 from .modelcard import ModelCard
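
One thing that makes the move stick: the "# isort:skip" trailer on the relocated import. isort sorts imports alphabetically within each section, so without the marker it would push the integrations import back down below the configuration imports on the next lint run. A self-contained illustration of the mechanism:

# isort sorts imports alphabetically; a trailing "isort:skip" comment
# exempts a line so it can stay pinned above imports that sort earlier.
import zlib  # isort:skip
import json
import os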

src/transformers/integrations.py

@@ -1,13 +1,7 @@
 # Integrations with other Python libraries
+import math
 import os
-import numpy as np
-
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
-from .utils import logging
-
-logger = logging.get_logger(__name__)
 
 try:
     import comet_ml  # noqa: F401
@@ -16,7 +10,6 @@ try:
     _has_comet = True
 except (ImportError):
     _has_comet = False
-
 
 try:
     import wandb
@@ -29,18 +22,6 @@ try:
 except (ImportError, AttributeError):
     _has_wandb = False
 
-try:
-    from torch.utils.tensorboard import SummaryWriter  # noqa: F401
-
-    _has_tensorboard = True
-except ImportError:
-    try:
-        from tensorboardX import SummaryWriter  # noqa: F401
-
-        _has_tensorboard = True
-    except ImportError:
-        _has_tensorboard = False
-
 try:
     import optuna  # noqa: F401
@@ -56,6 +37,29 @@ except (ImportError):
     _has_ray = False
 
+# No ML framework or transformer imports above this point
+
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun  # isort:skip
+from .utils import logging  # isort:skip
+
+logger = logging.get_logger(__name__)
+
+
+try:
+    from torch.utils.tensorboard import SummaryWriter  # noqa: F401
+
+    _has_tensorboard = True
+except ImportError:
+    try:
+        from tensorboardX import SummaryWriter  # noqa: F401
+
+        _has_tensorboard = True
+    except ImportError:
+        _has_tensorboard = False
+
+
 # Integration functions:
 def is_wandb_available():
     return _has_wandb
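
Callers are expected to consult these flags rather than re-attempt the import. A hypothetical usage sketch (the project name and logged value are placeholders, not from this commit):

from transformers.integrations import is_wandb_available

if is_wandb_available():
    import wandb  # safe: availability was already probed at module import

    wandb.init(project="demo")  # placeholder project name
    wandb.log({"loss": 0.25})   # placeholder metric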
@@ -135,7 +139,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     n_jobs = int(kwargs.pop("n_jobs", 1))
     num_gpus_per_trial = trainer.args.n_gpu
     if num_gpus_per_trial / n_jobs >= 1:
-        num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
+        num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
     kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
 
     if "reporter" not in kwargs: