mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update unwrap from accelerate (#29933)
* Use unwrap with the one in accelerate * oups * update unwrap * fix * wording * raise error instead * comment * doc * Update src/transformers/modeling_utils.py Co-authored-by: Zach Mueller <muellerzr@gmail.com> * style * put else --------- Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
parent
fbd8c51ffc
commit
b4fd49b6c5
@ -109,6 +109,7 @@ if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module
|
||||
from accelerate.utils import (
|
||||
check_tied_parameters_on_same_device,
|
||||
extract_model_from_parallel,
|
||||
find_tied_parameters,
|
||||
get_balanced_memory,
|
||||
get_max_memory,
|
||||
@ -4805,18 +4806,34 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||
def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
|
||||
"""
|
||||
Recursively unwraps a model from potential containers (as used in distributed training).
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to unwrap.
|
||||
recursive (`bool`, *optional*, defaults to `False`):
|
||||
Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
|
||||
recursively, not just the top-level distributed containers.
|
||||
"""
|
||||
# since there could be multiple levels of wrapping, unwrap recursively
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
# Use accelerate implementation if available (should always be the case when using torch)
|
||||
# This is for pytorch, as we also have to handle things like dynamo
|
||||
if is_accelerate_available():
|
||||
kwargs = {}
|
||||
if recursive:
|
||||
if not is_accelerate_available("0.29.0"):
|
||||
raise RuntimeError(
|
||||
"Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
|
||||
)
|
||||
else:
|
||||
kwargs["recursive"] = recursive
|
||||
return extract_model_from_parallel(model, **kwargs)
|
||||
else:
|
||||
return model
|
||||
# since there could be multiple levels of wrapping, unwrap recursively
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def expand_device_map(device_map, param_names, start_prefix):
|
||||
|
@ -63,7 +63,7 @@ from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_h
|
||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
||||
from .integrations.tpu import tpu_spmd_dataloader
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
|
||||
from .models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_MAPPING_NAMES,
|
||||
@ -684,7 +684,7 @@ class Trainer:
|
||||
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
|
||||
https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
unwrapped_model = unwrap_model(model)
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
|
||||
if _is_peft_model(unwrapped_model):
|
||||
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
|
||||
@ -705,7 +705,7 @@ class Trainer:
|
||||
if not hasattr(self, "neftune_hook_handle"):
|
||||
raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
|
||||
|
||||
unwrapped_model = unwrap_model(model)
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
|
||||
if _is_peft_model(unwrapped_model):
|
||||
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
|
||||
@ -1617,7 +1617,7 @@ class Trainer:
|
||||
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
|
||||
|
||||
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
|
||||
if unwrap_model(model) is not model:
|
||||
if self.accelerator.unwrap_model(model) is not model:
|
||||
return model
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
@ -3165,7 +3165,7 @@ class Trainer:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if labels is not None:
|
||||
unwrapped_model = unwrap_model(model)
|
||||
unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
if _is_peft_model(unwrapped_model):
|
||||
model_name = unwrapped_model.base_model.model._get_name()
|
||||
else:
|
||||
@ -3272,8 +3272,8 @@ class Trainer:
|
||||
supported_classes = (PushToHubMixin,)
|
||||
xm.rendezvous("saving_checkpoint")
|
||||
if not isinstance(model, supported_classes):
|
||||
if isinstance(unwrap_model(model), supported_classes):
|
||||
unwrap_model(model).save_pretrained(
|
||||
if isinstance(self.accelerator.unwrap_model(model), supported_classes):
|
||||
self.accelerator.unwrap_model(model).save_pretrained(
|
||||
output_dir,
|
||||
is_main_process=self.args.should_save,
|
||||
state_dict=model.state_dict(),
|
||||
@ -3311,8 +3311,8 @@ class Trainer:
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if isinstance(unwrap_model(self.model), supported_classes):
|
||||
unwrap_model(self.model).save_pretrained(
|
||||
if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
|
||||
self.accelerator.unwrap_model(self.model).save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
else:
|
||||
@ -3969,7 +3969,7 @@ class Trainer:
|
||||
f.write(model_card)
|
||||
|
||||
if is_peft_library:
|
||||
unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
|
||||
self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
|
||||
|
||||
def _push_from_checkpoint(self, checkpoint_folder):
|
||||
# Only push from one node.
|
||||
|
@ -123,7 +123,6 @@ if is_torch_available():
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
if is_safetensors_available():
|
||||
@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(
|
||||
getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
|
||||
)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
Loading…
Reference in New Issue
Block a user