mirror of https://github.com/huggingface/transformers.git
Update unwrap from accelerate (#29933)
* Use unwrap with the one in accelerate
* oops
* update unwrap
* fix
* wording
* raise error instead
* comment
* doc
* Update src/transformers/modeling_utils.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* style
* put else

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
parent fbd8c51ffc
commit b4fd49b6c5
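What the change means for callers: `unwrap_model` in `modeling_utils.py` now delegates to accelerate's `extract_model_from_parallel` whenever accelerate is installed, and `Trainer` calls `self.accelerator.unwrap_model(...)` directly. A minimal sketch of the new call surface, assuming accelerate is installed (the `DataParallel` wrapper below is only an illustration of a `.module` container, not part of the commit):

    import torch.nn as nn
    from transformers.modeling_utils import unwrap_model

    model = nn.Linear(4, 4)
    wrapped = nn.DataParallel(model)  # any container exposing `.module`

    assert unwrap_model(wrapped) is model
    # New in this commit: opt-in recursive unwrapping (needs accelerate >= 0.29.0)
    assert unwrap_model(wrapped, recursive=True) is model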
src/transformers/modeling_utils.py
@@ -109,6 +109,7 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
         find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
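Why the import: `extract_model_from_parallel` covers wrappers that the old `hasattr(model, "module")` loop could not, such as `torch.compile`'s `OptimizedModule`, which stores the original model in `_orig_mod` rather than `.module`. A hedged sketch of that case (behavior as observed in accelerate around v0.29; requires torch >= 2.0, and details may vary across versions):

    import torch
    import torch.nn as nn
    from accelerate.utils import extract_model_from_parallel

    inner = nn.Linear(4, 4)
    compiled = torch.compile(nn.DataParallel(inner))  # dynamo wrapper around a `.module` container

    unwrapped = extract_model_from_parallel(compiled)
    # The DataParallel layer is stripped, while the compile wrapper is kept intact.
    assert unwrapped._orig_mod is inner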
@@ -4805,18 +4806,34 @@ class SequenceSummary(nn.Module):
         return output
 
 
-def unwrap_model(model: nn.Module) -> nn.Module:
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     """
     Recursively unwraps a model from potential containers (as used in distributed training).
 
     Args:
         model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
     """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
     else:
-        return model
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model
 
 
 def expand_device_map(device_map, param_names, start_prefix):
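The `recursive` flag matters when a wrapped module sits below an unwrapped parent, which the default top-level pass leaves untouched. A hedged sketch of the difference, assuming accelerate >= 0.29.0 (the nesting is contrived for illustration; note that `recursive=True` replaces wrapped children in place):

    import torch.nn as nn
    from transformers.modeling_utils import unwrap_model

    inner = nn.Linear(4, 4)

    class Parent(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.DataParallel(inner)  # a `.module` container below the top level

    parent = Parent()
    assert unwrap_model(parent) is parent                       # default: nothing to strip at the top
    assert isinstance(parent.child, nn.DataParallel)            # ...and the child stays wrapped
    assert unwrap_model(parent, recursive=True).child is inner  # recursive: sublayers unwrapped too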
src/transformers/trainer.py
@@ -63,7 +63,7 @@ from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_h
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -684,7 +684,7 @@ class Trainer:
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
         https://arxiv.org/abs/2310.05914
         """
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
 
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
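`Accelerator.unwrap_model` is a thin wrapper around the same `extract_model_from_parallel` helper, so the `Trainer` can lean on the accelerator instance it already owns instead of importing the module-level utility. A small sketch of the equivalence (a single-process `Accelerator()` for illustration):

    import torch.nn as nn
    from accelerate import Accelerator
    from accelerate.utils import extract_model_from_parallel

    accelerator = Accelerator()
    wrapped = nn.DataParallel(nn.Linear(4, 4))

    # Both spellings resolve to the same underlying object.
    assert accelerator.unwrap_model(wrapped) is extract_model_from_parallel(wrapped)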
@@ -705,7 +705,7 @@ class Trainer:
         if not hasattr(self, "neftune_hook_handle"):
             raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
 
-        unwrapped_model = unwrap_model(model)
+        unwrapped_model = self.accelerator.unwrap_model(model)
 
         if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
@@ -1617,7 +1617,7 @@ class Trainer:
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
 
         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
-        if unwrap_model(model) is not model:
+        if self.accelerator.unwrap_model(model) is not model:
             return model
 
         # Mixed precision training with apex (torch < 1.6)
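This guard is what keeps `_wrap_model` idempotent: when train/eval runs more than once, an already-wrapped model unwraps to something other than itself, and the method returns early instead of wrapping again. A sketch of the identity test it relies on:

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = nn.Linear(4, 4)

    assert accelerator.unwrap_model(model) is model            # plain model: go ahead and wrap
    wrapped = nn.DataParallel(model)
    assert accelerator.unwrap_model(wrapped) is not wrapped    # already wrapped: return early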
@@ -3165,7 +3165,7 @@ class Trainer:
             self._past = outputs[self.args.past_index]
 
         if labels is not None:
-            unwrapped_model = unwrap_model(model)
+            unwrapped_model = self.accelerator.unwrap_model(model)
             if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
@@ -3272,8 +3272,8 @@ class Trainer:
             supported_classes = (PushToHubMixin,)
             xm.rendezvous("saving_checkpoint")
             if not isinstance(model, supported_classes):
-                if isinstance(unwrap_model(model), supported_classes):
-                    unwrap_model(model).save_pretrained(
+                if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                    self.accelerator.unwrap_model(model).save_pretrained(
                         output_dir,
                         is_main_process=self.args.should_save,
                         state_dict=model.state_dict(),
@@ -3311,8 +3311,8 @@ class Trainer:
         if state_dict is None:
             state_dict = self.model.state_dict()
 
-        if isinstance(unwrap_model(self.model), supported_classes):
-            unwrap_model(self.model).save_pretrained(
+        if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+            self.accelerator.unwrap_model(self.model).save_pretrained(
                 output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
             )
         else:
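The save path always unwraps before calling `save_pretrained`, since a distributed wrapper is not a `PreTrainedModel` and lacks that method, while the state dict can still be taken from the (possibly wrapped) training model. A minimal sketch outside of `Trainer` (the checkpoint name and output directory are illustrative):

    from accelerate import Accelerator
    from transformers import AutoModel

    accelerator = Accelerator()
    model = accelerator.prepare(AutoModel.from_pretrained("bert-base-uncased"))

    # The wrapper (if any) has no `save_pretrained`; the unwrapped model does.
    accelerator.unwrap_model(model).save_pretrained(
        "output_dir", state_dict=model.state_dict(), safe_serialization=True
    )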
@@ -3969,7 +3969,7 @@ class Trainer:
             f.write(model_card)
 
         if is_peft_library:
-            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
 
     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
tests/trainer/test_trainer.py
@@ -123,7 +123,6 @@ if is_torch_available():
         Trainer,
         TrainerState,
     )
-    from transformers.modeling_utils import unwrap_model
     from transformers.trainer_pt_utils import AcceleratorConfig
 
     if is_safetensors_available():
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         trainer = get_regression_trainer(learning_rate=0.1)
 
         def assert_flos_extraction(trainer, wrapped_model_to_check):
-            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
-            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
+            self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
+            self.assertGreaterEqual(
+                getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
+            )
 
         # with plain model
         assert_flos_extraction(trainer, trainer.model)