Improvements to Comet Integration (#14680)

* change args to address overwriting issue

* remove project name from args

* remove passing args as kwargs to experiment object

* remove passing args as kwargs to offline experiment

* fix offline directory assignment in experiment kwargs

* log checkpoint folder on training end

* log entire output_dir as asset folder

* log asset folder recursively

* end experiment at the end of training

* clean up

* clean up

* Default to always log training assets to Comet when using CometCallback

* change logging training assets to be true when running callback setup

* fix so that experiment always ends when training ends

* styling and quality fixes

* update docstring for COMET_LOG_ASSETS environment variable

* run styling and quality checks

* clean up docstring

* remove merge markers

* change asset logging to false to avoid hitting max assets per experiment limit

* update training asset description

* fix styling
This commit is contained in:
Dhruv Nair 2021-12-09 00:09:10 +05:30 committed by GitHub
parent 4ea19de80c
commit fe06f8dcac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -592,6 +592,7 @@ class CometCallback(TrainerCallback):
if not _has_comet:
raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
self._initialized = False
self._log_assets = False
def setup(self, args, state, model):
"""
@ -599,26 +600,35 @@ class CometCallback(TrainerCallback):
Environment:
COMET_MODE (:obj:`str`, `optional`):
"OFFLINE", "ONLINE", or "DISABLED"
Whether to create an online, offline experiment or disable Comet logging. Can be "OFFLINE", "ONLINE",
or "DISABLED". Defaults to "ONLINE".
COMET_PROJECT_NAME (:obj:`str`, `optional`):
Comet.ml project name for experiments
Comet project name for experiments
COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
COMET_LOG_ASSETS (:obj:`str`, `optional`):
Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be "TRUE", or
"FALSE". Defaults to "FALSE".
For a number of configurable items in the environment, see `here
<https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
"""
self._initialized = True
log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
if log_assets in {"TRUE", "1"}:
self._log_assets = True
if state.is_world_process_zero:
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
experiment = None
experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**args)
experiment = comet_ml.Experiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**args)
experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(model, framework="transformers")
@ -638,6 +648,16 @@ class CometCallback(TrainerCallback):
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
def on_train_end(self, args, state, control, **kwargs):
    """
    Optionally upload the output directory to Comet, then end the global experiment.

    Does nothing unless this callback has been set up (``self._initialized``) and the current process is
    the world-zero process. The contents of ``args.output_dir`` are uploaded recursively only when
    ``self._log_assets`` is enabled; an existing experiment is always ended.
    """
    if not (self._initialized and state.is_world_process_zero):
        return

    experiment = comet_ml.config.get_global_experiment()
    if experiment is None:
        # No global experiment exists (e.g. none was created during setup), so there is nothing to
        # upload or end. Calling ``experiment.end()`` here would raise AttributeError.
        return

    if self._log_assets:
        logger.info("Logging checkpoints. This may take time.")
        experiment.log_asset_folder(
            args.output_dir, recursive=True, log_file_name=True, step=state.global_step
        )
    experiment.end()
class AzureMLCallback(TrainerCallback):
"""