Save huggingface checkpoint as artifact in mlflow callback (#17686)

* Fix eval to compute rouge correctly for rouge_score

* styling

* moving sentence tokenization to utils from run_eval

* saving ckpt in mlflow

* use existing format of args

* fix documentation

Co-authored-by: Swetha Mandava <smandava@nvidia.com>
This commit is contained in:
Swetha Mandava 2022-06-17 13:14:03 -05:00 committed by GitHub
parent 21a772426d
commit 522a9ece4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -787,7 +787,7 @@ class MLflowCallback(TrainerCallback):
Environment:
HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
[`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
storage will just copy the files to your artifact location.
MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
@ -872,12 +872,20 @@ class MLflowCallback(TrainerCallback):
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
if self._log_artifacts:
logger.info("Logging artifacts. This may take time.")
self._ml_flow.log_artifacts(args.output_dir)
if self._auto_end_run and self._ml_flow.active_run():
self._ml_flow.end_run()
def on_save(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero and self._log_artifacts:
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
self._ml_flow.pyfunc.log_model(
ckpt_dir,
artifacts={"model_path": artifact_path},
python_model=self._ml_flow.pyfunc.PythonModel(),
)
def __del__(self):
# if the previous run is not terminated correctly, the fluent API will
# not let you start a new run before the previous one is killed