mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-05 05:40:05 +06:00
disable deepspeed when setting up fake trainer (#38101)
* disable deepspeed when setting up fake trainer
* Apply style fixes

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
7caa57e85e
commit
fe9426f12d
@@ -15,6 +15,7 @@
|
|||||||
Integrations with other Python libraries.
|
Integrations with other Python libraries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import importlib.util
|
import importlib.util
|
||||||
@@ -33,7 +34,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
|
||||||
from .. import PreTrainedModel, TFPreTrainedModel
|
from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments
|
||||||
from .. import __version__ as version
|
from .. import __version__ as version
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
@@ -929,13 +930,17 @@ class WandbCallback(TrainerCallback):
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.setup(args, state, model, **kwargs)
|
self.setup(args, state, model, **kwargs)
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
|
def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
|
||||||
if self._wandb is None:
|
if self._wandb is None:
|
||||||
return
|
return
|
||||||
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
|
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
|
||||||
from ..trainer import Trainer
|
from ..trainer import Trainer
|
||||||
|
|
||||||
fake_trainer = Trainer(args=args, model=model, processing_class=processing_class, eval_dataset=["fake"])
|
args_for_fake = copy.deepcopy(args)
|
||||||
|
args_for_fake.deepspeed = None
|
||||||
|
fake_trainer = Trainer(
|
||||||
|
args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
|
||||||
|
)
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
fake_trainer.save_model(temp_dir)
|
fake_trainer.save_model(temp_dir)
|
||||||
metadata = (
|
metadata = (
|
||||||
|
Loading…
Reference in New Issue
Block a user