Mirror of https://github.com/huggingface/transformers.git
Save huggingface checkpoint as artifact in mlflow callback (#17686)
* Fix eval to compute rouge correctly for rouge_score

* styling

* moving sentence tokenization to utils from run_eval

* saving ckpt in mlflow

* use existing format of args

* fix documentation

Co-authored-by: Swetha Mandava <smandava@nvidia.com>
This commit is contained in:
parent 21a772426d
commit 522a9ece4b
@@ -787,7 +787,7 @@ class MLflowCallback(TrainerCallback):
         Environment:
             HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                 Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
-                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
+                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
                 [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
                 storage will just copy the files to your artifact location.
             MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
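Not part of the diff, but for context: a minimal sketch of turning this behavior on from user code. The output_dir and save_steps values below are placeholders, not taken from this PR; with mlflow installed and HF_MLFLOW_LOG_ARTIFACTS set, the MLflowCallback attached by the Trainer uploads each saved checkpoint.

import os

# Enable checkpoint uploads; MLflowCallback reads this env var during setup().
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"

from transformers import TrainingArguments, Trainer

# Placeholder values, not taken from this PR.
training_args = TrainingArguments(
    output_dir="outputs",  # checkpoints are written here as checkpoint-<step>/
    save_steps=500,        # each save triggers the callback's on_save hook
)

# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()  # with mlflow installed, MLflowCallback is attached by default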
@@ -872,12 +872,20 @@ class MLflowCallback(TrainerCallback):

     def on_train_end(self, args, state, control, **kwargs):
         if self._initialized and state.is_world_process_zero:
-            if self._log_artifacts:
-                logger.info("Logging artifacts. This may take time.")
-                self._ml_flow.log_artifacts(args.output_dir)
             if self._auto_end_run and self._ml_flow.active_run():
                 self._ml_flow.end_run()

+    def on_save(self, args, state, control, **kwargs):
+        if self._initialized and state.is_world_process_zero and self._log_artifacts:
+            ckpt_dir = f"checkpoint-{state.global_step}"
+            artifact_path = os.path.join(args.output_dir, ckpt_dir)
+            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
+            self._ml_flow.pyfunc.log_model(
+                ckpt_dir,
+                artifacts={"model_path": artifact_path},
+                python_model=self._ml_flow.pyfunc.PythonModel(),
+            )
+
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
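A rough sketch (not from this commit) of pulling a checkpoint logged by the new on_save hook back down. The run id is a placeholder, and the "<ckpt_dir>/artifacts/model_path" path assumes the usual location where pyfunc copies entries of the artifacts dict; the exact layout can vary across MLflow versions.

import mlflow

# Placeholders, not from this PR.
run_id = "<run-id>"
ckpt_dir = "checkpoint-500"  # same naming scheme as on_save uses

client = mlflow.tracking.MlflowClient()
# Assumption: the HF checkpoint files were copied under <ckpt_dir>/artifacts/model_path
local_dir = client.download_artifacts(run_id, f"{ckpt_dir}/artifacts/model_path")

# from transformers import AutoModel
# model = AutoModel.from_pretrained(local_dir)  # reload the saved checkpoint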