TP initialization module-by-module (#35996)
* module-by-module loading!
* Update modeling_utils.py
* style and comments
* Update modeling_utils.py
* Update modeling_utils.py
* Update test
* Update modeling_utils.py
* Update modeling_utils.py
* Update test_tp.py
* Update test_tp.py
* Update modeling_utils.py
* re-trigger CIs
* re-trigger CIs
This commit is contained in:
parent 0863eef248
commit 60226c6ff3
@@ -787,6 +787,7 @@ def _load_state_dict_into_meta_model(
    keep_in_fp32_modules=None,
    unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
    device_mesh=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -796,6 +797,8 @@ def _load_state_dict_into_meta_model(
    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`

    It also initializes tensor parallelism for each module if needed.

    """

    # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
@@ -809,6 +812,12 @@ def _load_state_dict_into_meta_model(

    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

    # we need this later to initialize tensor parallelism
    if device_mesh is not None:
        full_tp_plan = model.config.base_model_tp_plan
        for submodule in model.modules():
            full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

    for param_name, param in state_dict.items():
        if param_name not in expected_keys:
            continue
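For context, the plan gathered above is just a mapping from wildcarded module-name patterns to parallelization styles, where "*" stands for the layer index. A minimal illustrative sketch of what such a mapping can look like (the keys and styles below are assumptions for a generic decoder, not copied from any particular config):

    # Illustrative only: keys and styles are assumptions, not taken from a specific model config.
    # "*" is later replaced by the regex "[0-9]+" when matching against module names.
    example_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }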
@@ -912,6 +921,37 @@ def _load_state_dict_into_meta_model(
                setattr(module, tensor_name, value)
        # TODO: consider removing used param_parts from state_dict before return

        # In this case, let's parallelize the modules!
        if device_mesh is not None:
            # Immediate parent
            split_parent_module_name = param_name.split(".")[:-1]
            parent_module_name = ".".join(split_parent_module_name)
            parent_module = model
            for name in split_parent_module_name:
                parent_module = getattr(parent_module, name)

            # Check if we are part of the tp_plan
            current_module_plan = None
            for param, plan in full_tp_plan.items():
                # "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
                pattern = param.replace("*", "[0-9]+")
                if re.search(pattern, parent_module_name):
                    current_module_plan = plan
                    break

            # We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
            # if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized)
            process_device = list(device_map.values())[0]
            all_module_parameters_initialized = all(
                m.device == process_device for m in parent_module.parameters(recurse=False)
            ) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
            if current_module_plan is not None and all_module_parameters_initialized:
                torch.distributed.tensor.parallel.parallelize_module(
                    parent_module,
                    device_mesh=device_mesh,
                    parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
                )

    return error_msgs, offload_index, state_dict_index
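To make the matching and per-module parallelization step concrete, here is a minimal self-contained sketch of the same idea (the diff's `translate_to_torch_parallel_style` maps plan strings such as "colwise"/"rowwise" to the torch `ParallelStyle` objects used below; the `STYLE` table and function name are assumptions for illustration):

    import re

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    # Assumed stand-in for `translate_to_torch_parallel_style`: map plan strings to torch styles.
    STYLE = {"colwise": ColwiseParallel(), "rowwise": RowwiseParallel()}

    def maybe_parallelize(parent_module, parent_module_name, tp_plan, device_mesh):
        # "*" in the plan keys is a placeholder for the layer index, so turn it into a regex.
        for pattern, plan in tp_plan.items():
            if re.search(pattern.replace("*", "[0-9]+"), parent_module_name):
                # Shard this single module across the 1-D mesh once all its params are materialized.
                parallelize_module(parent_module, device_mesh=device_mesh, parallelize_plan=STYLE[plan])
                return True
        return False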
@@ -3489,12 +3529,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            )

        # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
        # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
        # child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
        # And temporarily setting the default device to the current process rank results in the following error:
        # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
        tp_device = None
        # `device_map` pointing to the correct device
        device_mesh = None
        if tp_plan is not None:
            if not is_torch_greater_or_equal("2.5"):
                raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
            if not torch.distributed.is_initialized():
                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
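A rough sketch of the dispatch idea described in that comment, assuming one process per local GPU with ranks assigned by the launcher (the actual derivation of `tp_device` happens in code not shown in this diff):

    import torch
    import torch.distributed as dist

    # Assumption: one process per local GPU, ranks assigned contiguously by the launcher.
    local_rank = dist.get_rank() % torch.cuda.device_count()
    tp_device = torch.device("cuda", local_rank)
    # Used as a plain `device_map` so every weight lands on this process's GPU (see the next hunk).
    device_map = tp_device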
@@ -3506,6 +3545,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            # This is the easiest way to dispatch to the current process device
            device_map = tp_device

            # Assuming sharding the model onto the world
            world_size = torch.distributed.get_world_size()
            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))

        if is_fsdp_enabled():
            low_cpu_mem_usage = True
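As a usage note, the mesh created here is one-dimensional and simply spans every process in the job. For example, a run launched with `torchrun --nproc-per-node 4 script.py` gives:

    world_size = torch.distributed.get_world_size()  # 4 in this example
    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
    # device_mesh.size() == 4: each rank owns one slot of the mesh and one shard of every split weight.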
@@ -3600,7 +3643,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        elif not low_cpu_mem_usage:
            raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
            raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")

        if low_cpu_mem_usage:
            if is_deepspeed_zero3_enabled():
@@ -3609,7 +3652,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                )
            elif not is_accelerate_available():
                raise ImportError(
                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                    f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                )

        # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
@@ -4186,6 +4229,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            # Let's make sure we don't run the init function of buffer modules
            model = cls(config, *model_args, **model_kwargs)

        if device_mesh is not None and not model.supports_tp_plan:
            raise NotImplementedError("This model does not have a tensor parallel plan.")

        # make sure we use the model's config since the __init__ call might have copied it
        config = model.config
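`supports_tp_plan` is not shown in this diff; presumably it boils down to checking that the model's config ships a TP plan. A sketch of what such a property might look like (an assumption, not the actual implementation):

    @property
    def supports_tp_plan(self):
        # Assumed behaviour: a model supports TP when its config declares a base plan.
        return getattr(self.config, "base_model_tp_plan", None) is not None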
@@ -4336,6 +4382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                keep_in_fp32_modules=keep_in_fp32_modules,
                gguf_path=gguf_path,
                weights_only=weights_only,
                device_mesh=device_mesh,
            )

        # make sure token embedding weights are still tied if needed
@@ -4370,8 +4417,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                )
                pass

        # Dispatch model with hooks on all devices if necessary
        if device_map is not None:
        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
        # harms performance)
        if device_map is not None and device_mesh is None:
            device_map_kwargs = {
                "device_map": device_map,
                "offload_dir": offload_folder,
@@ -4398,6 +4446,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
                dispatch_model(model, **device_map_kwargs)

        # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
        # not part of the state_dict (persistent=False)
        if device_mesh is not None:
            for buffer in model.buffers():
                if buffer.device != tp_device:
                    buffer.data = buffer.to(tp_device)

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(model, config=config)
            model.hf_quantizer = hf_quantizer
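The buffer fix-up above is needed because buffers registered with `persistent=False` (such as rotary embedding tables) never appear in the checkpoint's state dict, so the per-parameter loading loop never moves them to the right device. A tiny self-contained illustration of that behaviour (class and buffer names are illustrative):

    import torch
    import torch.nn as nn

    class Rotary(nn.Module):
        def __init__(self):
            super().__init__()
            # Non-persistent buffers are excluded from state_dict(), so loading a checkpoint never touches them.
            self.register_buffer("inv_freq", torch.arange(4, dtype=torch.float32), persistent=False)

    print(Rotary().state_dict())  # empty OrderedDict -> hence the explicit buffer.to(tp_device) above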
@@ -4420,16 +4475,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            }
            return model, loading_info

        if tp_plan is not None:
            assert tp_device is not None, "tp_device not set!"
            if not model.supports_tp_plan:
                raise NotImplementedError("This model does not have a tensor parallel plan.")
            # Assuming sharding the model onto the world
            world_size = torch.distributed.get_world_size()
            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
            # Apply Tensor Parallelism
            model.tensor_parallel(device_mesh)

        return model

    @staticmethod
@@ -4523,6 +4568,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        keep_in_fp32_modules=None,
        gguf_path=None,
        weights_only=True,
        device_mesh=None,
    ):
        is_safetensors = False
        is_quantized = hf_quantizer is not None
@@ -4822,6 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                    unexpected_keys=unexpected_keys,
                    device_mesh=device_mesh,
                )
            else:
                # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -4911,6 +4958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                        is_safetensors=is_safetensors,
                        keep_in_fp32_modules=keep_in_fp32_modules,
                        unexpected_keys=unexpected_keys,
                        device_mesh=device_mesh,
                    )
                    error_msgs += new_error_msgs
                else:
@@ -5188,7 +5236,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

    def tensor_parallel(self, device_mesh):
        """
        Tensor parallelize the model across the given device mesh.
        Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
        was already loaded in memory; note, however, that this means each process will first initialize the whole model,
        then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time.

        Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during initialization,
        so that the expected per-device memory spike at loading time is not larger than the final model size on each device.

        Args:
            device_mesh (`torch.distributed.DeviceMesh`):
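For the preferred path mentioned in the docstring, a minimal end-to-end sketch (the checkpoint name is illustrative; run with one process per GPU, e.g. `torchrun --nproc-per-node 2 run_tp.py`):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # tp_plan="auto" requires an initialized process group (see the checks added above).
    torch.distributed.init_process_group("nccl")

    model_id = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative checkpoint whose config ships a TP plan
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("Tensor parallelism shards each layer across GPUs.", return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits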
@@ -81,17 +81,13 @@ class TestTensorParallel(TestCasePlus):
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
        torch.distributed.barrier()

        # The expected full model memory footprint
        expected_model_memory = 16
        # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
        expected_model_memory_per_device = (16 / world_size) + 1
        overhead_factor = 1.2

        # Assert we did not use more than the full model expected memory (with some overhead)
        if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
            raise ValueError("Loading the model used more than the full model size")

        # Assert we correctly handled the sharding between devices
        if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
            raise ValueError("Each model shard is larger than what is expected.")
        # Check that we do not use more than the expected sharded size during initialization
        if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
            raise ValueError("Loading the model used more than the expected fraction of model size per device")

        torch.distributed.barrier()
        torch.distributed.destroy_process_group()
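As a quick sanity check on the numbers used above: for the roughly 16 GB fp16 model in this test, with a world size of 2 for example, the expected per-device footprint is 16 / 2 + 1 = 9 GB (the extra 1 GB accounts for modules that are not split, such as the embeddings), so with the 1.2 overhead factor the assertion tolerates at most 9 * 1.2 = 10.8 GB of peak allocated memory per device, well below the 16 GB a full per-process copy of the model would require.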