move wandb/comet logger init to train() to allow parallel logging (#6850)

* move wandb/comet logger init to train() to allow parallel logging

* Setup wandb/comet loggers on first call to log()
This commit is contained in:
krfricke 2020-09-03 16:49:14 +01:00 committed by GitHub
parent 39ed68d597
commit 0f360d3d1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -255,20 +255,10 @@ class Trainer:
logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
)
if is_wandb_available():
self.setup_wandb()
elif os.environ.get("WANDB_DISABLED") != "true":
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False
# Create output directory if needed
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
@ -518,6 +508,25 @@ class Trainer:
"""
return len(dataloader.dataset)
def _setup_loggers(self):
if self._loggers_initialized:
return
if is_wandb_available():
self.setup_wandb()
elif os.environ.get("WANDB_DISABLED") != "true":
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
self._loggers_initialized = True
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
if self.hp_search_backend is None or trial is None:
@ -903,6 +912,9 @@ class Trainer:
iterator (:obj:`tqdm`, `optional`):
A potential tqdm progress bar to write the logs on.
"""
# Set up loggers like W&B or Comet ML
self._setup_loggers()
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",