diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index f18eb43c789..3efefe712d7 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -15,6 +15,7 @@
 Integrations with other Python libraries.
 """
 
+import copy
 import functools
 import importlib.metadata
 import importlib.util
@@ -33,7 +34,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
 import numpy as np
 import packaging.version
 
-from .. import PreTrainedModel, TFPreTrainedModel
+from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments
 from .. import __version__ as version
 from ..utils import (
     PushToHubMixin,
@@ -929,13 +930,17 @@ class WandbCallback(TrainerCallback):
         if not self._initialized:
             self.setup(args, state, model, **kwargs)
 
-    def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
+    def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
         if self._wandb is None:
             return
         if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
             from ..trainer import Trainer
 
-            fake_trainer = Trainer(args=args, model=model, processing_class=processing_class, eval_dataset=["fake"])
+            args_for_fake = copy.deepcopy(args)
+            args_for_fake.deepspeed = None
+            fake_trainer = Trainer(
+                args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
+            )
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
                 metadata = (
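
For context, a minimal standalone sketch of the pattern this patch introduces: deep-copy the TrainingArguments and clear the `deepspeed` field before building the throwaway Trainer that WandbCallback uses only to serialize the model for an artifact. The helper names `strip_deepspeed_args` and `save_for_artifact` are illustrative, not part of the patch; only `copy.deepcopy(args)` / `args_for_fake.deepspeed = None` and the Trainer/save_model calls mirror the diff above.

# Sketch only; assumes a transformers version whose Trainer accepts
# `processing_class` (as in the patched code).
import copy
import tempfile

from transformers import Trainer, TrainingArguments


def strip_deepspeed_args(args: TrainingArguments) -> TrainingArguments:
    # Deep-copy so the live TrainingArguments still in use by the Trainer
    # and other callbacks are left untouched.
    args_for_fake = copy.deepcopy(args)
    # With `deepspeed` still set, constructing a fresh Trainer at train end
    # would try to set up the DeepSpeed engine again; clearing the field
    # keeps the throwaway Trainer a plain model serializer.
    args_for_fake.deepspeed = None
    return args_for_fake


def save_for_artifact(args, model, processing_class):
    fake_trainer = Trainer(
        args=strip_deepspeed_args(args),
        model=model,
        processing_class=processing_class,
        eval_dataset=["fake"],  # placeholder only; nothing is ever evaluated
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        # Writes config and weights to disk; an artifact upload would read
        # from temp_dir inside this block, before the directory is removed.
        fake_trainer.save_model(temp_dir)

Copying rather than mutating `args` in place matters here: `on_train_end` receives the same TrainingArguments object the rest of the run sees, so nulling `deepspeed` directly on it could corrupt state for other callbacks.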