Update unwrap from accelerate (#29933)

* Use the unwrap implementation from accelerate

* oops

* update unwrap

* fix

* wording

* raise error instead

* comment

* doc

* Update src/transformers/modeling_utils.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* style

* put else

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Marc Sun 2024-04-19 18:05:34 +02:00 committed by GitHub
parent fbd8c51ffc
commit b4fd49b6c5
3 changed files with 36 additions and 18 deletions

src/transformers/modeling_utils.py

@@ -109,6 +109,7 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
         find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
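
For context, extract_model_from_parallel is accelerate's general-purpose unwrapper for distributed containers (DistributedDataParallel, DataParallel, dynamo-compiled modules, ...). A minimal sketch of its behavior, assuming accelerate is installed; the DataParallel wrapper is just an easy stand-in for a distributed container:

    import torch.nn as nn
    from accelerate.utils import extract_model_from_parallel

    inner = nn.Linear(2, 2)
    wrapped = nn.DataParallel(inner)  # one of the container types the helper strips
    assert extract_model_from_parallel(wrapped) is inner  # returns the bare module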
@@ -4805,18 +4806,34 @@ class SequenceSummary(nn.Module):
         return output


-def unwrap_model(model: nn.Module) -> nn.Module:
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).

     Args:
         model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
     """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
     else:
-        return model
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model


 def expand_device_map(device_map, param_names, start_prefix):
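
A minimal sketch of the updated helper's behavior (illustrative only; recursive=True additionally requires accelerate >= 0.29.0, otherwise the RuntimeError above is raised):

    import torch.nn as nn
    from transformers.modeling_utils import unwrap_model

    inner = nn.Linear(4, 4)
    wrapped = nn.DataParallel(inner)  # adds the .module indirection that DDP also adds

    assert unwrap_model(wrapped) is inner
    # recursive=True also unwraps wrapped submodules inside the model,
    # not just the top-level container (needs accelerate >= 0.29.0)
    assert unwrap_model(wrapped, recursive=True) is inner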

src/transformers/trainer.py

@@ -63,7 +63,7 @@ from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_h
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ class Trainer:
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
         https://arxiv.org/abs/2310.05914
         """
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)

         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -705,7 +705,7 @@ class Trainer:
         if not hasattr(self, "neftune_hook_handle"):
             raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)

         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -1617,7 +1617,7 @@ class Trainer:
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
-        if unwrap_model(model) is not model:
+        if self.accelerator.unwrap_model(model) is not model:
             return model

         # Mixed precision training with apex (torch < 1.6)
@@ -3165,7 +3165,7 @@ class Trainer:
                 self._past = outputs[self.args.past_index]

             if labels is not None:
-                unwrapped_model = unwrap_model(model)
+                unwrapped_model = self.accelerator.unwrap_model(model)
                 if _is_peft_model(unwrapped_model):
                     model_name = unwrapped_model.base_model.model._get_name()
                 else:
@@ -3272,8 +3272,8 @@ class Trainer:
             supported_classes = (PushToHubMixin,)
             xm.rendezvous("saving_checkpoint")
             if not isinstance(model, supported_classes):
-                if isinstance(unwrap_model(model), supported_classes):
-                    unwrap_model(model).save_pretrained(
+                if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                    self.accelerator.unwrap_model(model).save_pretrained(
                         output_dir,
                         is_main_process=self.args.should_save,
                         state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ class Trainer:
         if state_dict is None:
             state_dict = self.model.state_dict()

-        if isinstance(unwrap_model(self.model), supported_classes):
-            unwrap_model(self.model).save_pretrained(
+        if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+            self.accelerator.unwrap_model(self.model).save_pretrained(
                 output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
             )
         else:
@@ -3969,7 +3969,7 @@ class Trainer:
                 f.write(model_card)

         if is_peft_library:
-            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)

     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
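
All Trainer call sites above now go through Accelerator.unwrap_model, which delegates to the same accelerate utility. A hedged sketch of the equivalent call outside of Trainer:

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(nn.Linear(4, 4))  # may wrap the model, depending on the launch setup
    unwrapped = accelerator.unwrap_model(model)   # returns the bare nn.Module either way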

tests/trainer/test_trainer.py

@@ -123,7 +123,6 @@ if is_torch_available():
         Trainer,
         TrainerState,
     )
-    from transformers.modeling_utils import unwrap_model
     from transformers.trainer_pt_utils import AcceleratorConfig

 if is_safetensors_available():
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer = get_regression_trainer(learning_rate=0.1)

         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
+            self.assertGreaterEqual(
+                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
+            )

         # with plain model
         assert_flos_extraction(trainer, trainer.model)