TP initialization module-by-module (#35996)

* module-by-module loading!

* Update modeling_utils.py

* style and comments

* Update modeling_utils.py

* Update modeling_utils.py

* Update test

* Update modeling_utils.py

* Update modeling_utils.py

* Update test_tp.py

* Update test_tp.py

* Update modeling_utils.py

* re-trigger CIs

* re-trigger CIs
Cyril Vallez 2025-02-19 14:04:57 +01:00 committed by GitHub
parent 0863eef248
commit 60226c6ff3
2 changed files with 78 additions and 29 deletions

modeling_utils.py

@@ -787,6 +787,7 @@ def _load_state_dict_into_meta_model(
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
device_mesh=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -796,6 +797,8 @@ def _load_state_dict_into_meta_model(
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
`bert.pooler.dense.weight`
It also initializes tensor parallelism for each module if needed.
"""
# XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
@@ -809,6 +812,12 @@ def _load_state_dict_into_meta_model(
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
for param_name, param in state_dict.items():
if param_name not in expected_keys:
continue
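To make the merged plan concrete, here is a small illustrative example of what `full_tp_plan` can end up holding once the config's `base_model_tp_plan` and the per-submodule `_tp_plan` attributes are merged (the keys and partition styles below are hypothetical, not read from any specific model config):

# Hypothetical merged tensor-parallel plan: keys are module-name patterns where
# "*" stands in for a layer index, values are partitioning styles.
full_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}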
@@ -912,6 +921,37 @@ def _load_state_dict_into_meta_model(
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return
# In this case, let's parallelize the modules!
if device_mesh is not None:
# Immediate parent
split_parent_module_name = param_name.split(".")[:-1]
parent_module_name = ".".join(split_parent_module_name)
parent_module = model
for name in split_parent_module_name:
parent_module = getattr(parent_module, name)
# Check if we are part of the tp_plan
current_module_plan = None
for param, plan in full_tp_plan.items():
# "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
pattern = param.replace("*", "[0-9]+")
if re.search(pattern, parent_module_name):
current_module_plan = plan
break
# We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
# if there is a bias, we need both the `weight` and `bias` of an nn.Linear to be initialized)
process_device = list(device_map.values())[0]
all_module_parameters_initialized = all(
m.device == process_device for m in parent_module.parameters(recurse=False)
) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
if current_module_plan is not None and all_module_parameters_initialized:
torch.distributed.tensor.parallel.parallelize_module(
parent_module,
device_mesh=device_mesh,
parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
)
return error_msgs, offload_index, state_dict_index
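As a standalone illustration of the pattern matching used above to decide whether a just-loaded module is covered by the plan (the parameter name and plan entry are hypothetical):

import re

full_tp_plan = {"layers.*.self_attn.q_proj": "colwise"}  # hypothetical plan entry
param_name = "model.layers.3.self_attn.q_proj.weight"  # hypothetical parameter name

# Immediate parent module of the parameter, as computed in the loop above
parent_module_name = ".".join(param_name.split(".")[:-1])  # "model.layers.3.self_attn.q_proj"

current_module_plan = None
for key, plan in full_tp_plan.items():
    # "*" stands for a layer index, so it becomes "[0-9]+" in the regex pattern
    pattern = key.replace("*", "[0-9]+")
    if re.search(pattern, parent_module_name):
        current_module_plan = plan
        break

print(current_module_plan)  # "colwise"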
@@ -3489,12 +3529,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
# child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
# And temporarily setting the default device to the current process rank results in the following error:
# `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
tp_device = None
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
@@ -3506,6 +3545,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assume the model is sharded across the whole world (all processes)
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
if is_fsdp_enabled():
low_cpu_mem_usage = True
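For reference, a minimal sketch of what this per-process setup amounts to (single node assumed, so the global rank doubles as the local GPU index; the process group is expected to be initialized already, e.g. by torchrun):

import torch

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# The device this process owns; used as a trivial `device_map`
tp_device = torch.device("cuda", rank)

# A 1-D mesh spanning all ranks: every module covered by the tp_plan is sharded across it
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))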
@@ -3600,7 +3643,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
if is_deepspeed_zero3_enabled():
@@ -3609,7 +3652,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
elif not is_accelerate_available():
raise ImportError(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
@@ -4186,6 +4229,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@@ -4336,6 +4382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
device_mesh=device_mesh,
)
# make sure token embedding weights are still tied if needed
@@ -4370,8 +4417,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
pass
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harms performance)
if device_map is not None and device_mesh is None:
device_map_kwargs = {
"device_map": device_map,
"offload_dir": offload_folder,
@@ -4398,6 +4446,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
# not part of the state_dict (persistent=False)
if device_mesh is not None:
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer
@@ -4420,16 +4475,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
}
return model, loading_info
if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assume the model is sharded across the whole world (all processes)
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)
return model
@staticmethod
@@ -4523,6 +4568,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=None,
gguf_path=None,
weights_only=True,
device_mesh=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
@@ -4822,6 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -4911,6 +4958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
error_msgs += new_error_msgs
else:
@@ -5188,7 +5236,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.
Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
has already been loaded in memory. Note, however, that this means each process will first initialize the whole model,
then parallelize it across devices. This wastes a lot of GPU memory and can lead to OOM at loading time.
Calling `from_pretrained(..., tp_plan="auto")` is preferred: it parallelizes module-by-module during initialization,
so that the expected per-device memory spike at loading time is no larger than the final model size on each device.
Args:
device_mesh (`torch.distributed.DeviceMesh`):
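For completeness, a minimal usage sketch of the preferred `from_pretrained` path described in the docstring, meant to be launched with something like `torchrun --nproc-per-node 2 run_tp.py` (the model id is a placeholder; any checkpoint whose config ships a tensor parallel plan should work):

import os
import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint with a TP plan

# torchrun sets RANK; bind the NCCL process group to this process's GPU
rank = int(os.environ["RANK"])
torch.distributed.init_process_group("nccl", device_id=torch.device("cuda", rank))

# Shards the model module-by-module while loading, instead of materializing the
# full model on every rank and calling model.tensor_parallel(device_mesh) afterwards
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")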

test_tp.py

@@ -81,17 +81,13 @@ class TestTensorParallel(TestCasePlus):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
torch.distributed.barrier()
# The expected full model memory footprint
expected_model_memory = 16
# The expected model memory footprint per device, in GiB. We add 1 GiB since not all modules are split (e.g. the embeddings)
expected_model_memory_per_device = (16 / world_size) + 1
overhead_factor = 1.2
# Assert we did not use more than the full model expected memory (with some overhead)
if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
raise ValueError("Loading the model used more than the full model size")
# Assert we correctly handled the sharding between devices
if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
raise ValueError("Each model shard is larger than what is expected.")
# Check that we do not use more than the expected sharded size during initialization
if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
raise ValueError("Loading the model used more than the expected fraction of model size per device")
torch.distributed.barrier()
torch.distributed.destroy_process_group()
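For concreteness, with the roughly 16 GiB model used here and e.g. world_size = 2, the expected per-device footprint is (16 / 2) + 1 = 9 GiB, so the check tolerates a peak allocation of at most 9 * 1.2 = 10.8 GiB per device, well below the 16 * 1.2 = 19.2 GiB that the previous full-model bound allowed.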