diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 2194d113e50..57bc0251fbe 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -592,6 +592,7 @@ class CometCallback(TrainerCallback): if not _has_comet: raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.") self._initialized = False + self._log_assets = False def setup(self, args, state, model): """ @@ -599,26 +600,35 @@ class CometCallback(TrainerCallback): Environment: COMET_MODE (:obj:`str`, `optional`): - "OFFLINE", "ONLINE", or "DISABLED" + Whether to create an online, offline experiment or disable Comet logging. Can be "OFFLINE", "ONLINE", + or "DISABLED". Defaults to "ONLINE". COMET_PROJECT_NAME (:obj:`str`, `optional`): - Comet.ml project name for experiments + Comet project name for experiments COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`): Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE" + COMET_LOG_ASSETS (:obj:`str`, `optional`): + Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be "TRUE", or + "FALSE". Defaults to "TRUE". For a number of configurable items in the environment, see `here `__. """ self._initialized = True + log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper() + if log_assets in {"TRUE", "1"}: + self._log_assets = True if state.is_world_process_zero: comet_mode = os.getenv("COMET_MODE", "ONLINE").upper() - args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")} experiment = None + experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")} if comet_mode == "ONLINE": - experiment = comet_ml.Experiment(**args) + experiment = comet_ml.Experiment(**experiment_kwargs) + experiment.log_other("Created from", "transformers") logger.info("Automatic Comet.ml online logging enabled") elif comet_mode == "OFFLINE": - args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./") - experiment = comet_ml.OfflineExperiment(**args) + experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./") + experiment = comet_ml.OfflineExperiment(**experiment_kwargs) + experiment.log_other("Created from", "transformers") logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished") if experiment is not None: experiment._set_model_graph(model, framework="transformers") @@ -638,6 +648,16 @@ class CometCallback(TrainerCallback): if experiment is not None: experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers") + def on_train_end(self, args, state, control, **kwargs): + if self._initialized and state.is_world_process_zero: + experiment = comet_ml.config.get_global_experiment() + if (experiment is not None) and (self._log_assets is True): + logger.info("Logging checkpoints. This may take time.") + experiment.log_asset_folder( + args.output_dir, recursive=True, log_file_name=True, step=state.global_step + ) + experiment.end() + class AzureMLCallback(TrainerCallback): """