disable deepspeed when setting up fake trainer (#38101)

* disable deepspeed when setting up fake trainer

* Apply style fixes

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Author: Wing Lian
Date:   2025-05-15 09:34:04 -04:00 (committed by GitHub)
Parent: 7caa57e85e
Commit: fe9426f12d

@@ -15,6 +15,7 @@
 Integrations with other Python libraries.
 """
+import copy
 import functools
 import importlib.metadata
 import importlib.util
@@ -33,7 +34,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
 import numpy as np
 import packaging.version
 
-from .. import PreTrainedModel, TFPreTrainedModel
+from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments
 from .. import __version__ as version
 from ..utils import (
     PushToHubMixin,
@@ -929,13 +930,17 @@ class WandbCallback(TrainerCallback):
         if not self._initialized:
             self.setup(args, state, model, **kwargs)
 
-    def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
+    def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
         if self._wandb is None:
             return
         if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
             from ..trainer import Trainer
 
-            fake_trainer = Trainer(args=args, model=model, processing_class=processing_class, eval_dataset=["fake"])
+            args_for_fake = copy.deepcopy(args)
+            args_for_fake.deepspeed = None
+            fake_trainer = Trainer(
+                args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
+            )
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
                 metadata = (
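
The change boils down to deep-copying the TrainingArguments and clearing their DeepSpeed setting before building the throwaway Trainer that exists only to call save_model() for the W&B artifact, presumably so that constructing it does not trigger DeepSpeed initialization outside of real training. A minimal sketch of that pattern in isolation; the helper name `_args_without_deepspeed` is hypothetical and not part of the library:

import copy


def _args_without_deepspeed(args):
    # Hypothetical helper mirroring the inlined fix above: return a deep copy
    # of the TrainingArguments with the DeepSpeed config removed, so a
    # throwaway Trainer built from it only saves the model and never tries to
    # set up a DeepSpeed engine.
    args_for_fake = copy.deepcopy(args)
    args_for_fake.deepspeed = None
    return args_for_fake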