Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
TP initialization module-by-module (#35996)
* module-by-module loading!
* Update modeling_utils.py
* dtyle and comments
* Update modeling_utils.py
* Update modeling_utils.py
* Update test
* Update modeling_utils.py
* Update modeling_utils.py
* Update test_tp.py
* Update test_tp.py
* Update modeling_utils.py
* re-trigger CIs
* re-trigger CIs
Parent: 0863eef248
Commit: 60226c6ff3
modeling_utils.py

@@ -787,6 +787,7 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_modules=None,
     unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
     pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
+    device_mesh=None,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -796,6 +797,8 @@ def _load_state_dict_into_meta_model(
     `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
     `bert.pooler.dense.weight`

+    It also initializes tensor parallelism for each module if needed.
+
     """

     # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
@@ -809,6 +812,12 @@ def _load_state_dict_into_meta_model(

     is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

+    # we need this later to initialize tensor parallelism
+    if device_mesh is not None:
+        full_tp_plan = model.config.base_model_tp_plan
+        for submodule in model.modules():
+            full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
+
     for param_name, param in state_dict.items():
         if param_name not in expected_keys:
             continue
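For orientation, `base_model_tp_plan` and the per-module `_tp_plan` attributes are plain mappings from wildcarded module paths to sharding styles; the code above merges them, and the lookup later in this diff matches module names against the wildcards. A minimal, hypothetical sketch of that shape (the keys and styles here are illustrative, not taken from this diff):

import re

# Illustrative config-level plan; real plans come from `config.base_model_tp_plan`.
config_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}
# Some submodules may carry their own `_tp_plan`; the loop above folds these in.
submodule_plans = [{"layers.*.mlp.down_proj": "rowwise"}]

full_tp_plan = dict(config_plan)
for extra in submodule_plans:
    full_tp_plan.update(extra)

def find_plan(parent_module_name, full_tp_plan):
    """Return the sharding style whose wildcard pattern matches the module name, if any."""
    for key, plan in full_tp_plan.items():
        pattern = key.replace("*", "[0-9]+")  # "*" stands for a layer index
        if re.search(pattern, parent_module_name):
            return plan
    return None

print(find_plan("model.layers.11.self_attn.q_proj", full_tp_plan))  # colwise
print(find_plan("model.embed_tokens", full_tp_plan))                # None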
@@ -912,6 +921,37 @@ def _load_state_dict_into_meta_model(
                 setattr(module, tensor_name, value)
         # TODO: consider removing used param_parts from state_dict before return

+        # In this case, let's parallelize the modules!
+        if device_mesh is not None:
+            # Immediate parent
+            split_parent_module_name = param_name.split(".")[:-1]
+            parent_module_name = ".".join(split_parent_module_name)
+            parent_module = model
+            for name in split_parent_module_name:
+                parent_module = getattr(parent_module, name)
+
+            # Check if we are part of the tp_plan
+            current_module_plan = None
+            for param, plan in full_tp_plan.items():
+                # "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern
+                pattern = param.replace("*", "[0-9]+")
+                if re.search(pattern, parent_module_name):
+                    current_module_plan = plan
+                    break
+
+            # We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g.
+            # if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized)
+            process_device = list(device_map.values())[0]
+            all_module_parameters_initialized = all(
+                m.device == process_device for m in parent_module.parameters(recurse=False)
+            ) and all(m.device == process_device for m in parent_module.buffers(recurse=False))
+            if current_module_plan is not None and all_module_parameters_initialized:
+                torch.distributed.tensor.parallel.parallelize_module(
+                    parent_module,
+                    device_mesh=device_mesh,
+                    parallelize_plan=translate_to_torch_parallel_style(current_module_plan),
+                )
+
     return error_msgs, offload_index, state_dict_index

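Once a parent module has all of its parameters materialized on the process device, the plan string is translated to a torch `ParallelStyle` and applied with `parallelize_module`. Below is a standalone sketch of that last step on a toy model, assuming a `torchrun` launch; the module names, sizes and plan are made up for illustration and are not part of this diff:

# Launch with: torchrun --nproc-per-node=2 tp_sketch.py
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.distributed.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

# Stand-in for `translate_to_torch_parallel_style`: plan strings -> torch ParallelStyle objects
styles = {"colwise": ColwiseParallel(), "rowwise": RowwiseParallel()}
tp_plan = {"up": "colwise", "down": "rowwise"}  # hypothetical plan for this toy model

model = nn.ModuleDict({"up": nn.Linear(16, 64), "down": nn.Linear(64, 16)}).cuda()
for name, module in model.named_children():
    style = tp_plan.get(name)
    if style is not None:
        # shard this module's weights in place across the 1-D mesh
        parallelize_module(module, device_mesh=mesh, parallelize_plan=styles[style])

torch.distributed.destroy_process_group()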
@@ -3489,12 +3529,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )

         # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
-        # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
-        # childs processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
-        # And temporarily setting the default device to current process rank result in the following error
-        # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
-        tp_device = None
+        # `device_map` pointing to the correct device
+        device_mesh = None
         if tp_plan is not None:
+            if not is_torch_greater_or_equal("2.5"):
+                raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
             if not torch.distributed.is_initialized():
                 raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

@@ -3506,6 +3545,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # This is the easiest way to dispatch to the current process device
             device_map = tp_device

+            # Assuming sharding the model onto the world
+            world_size = torch.distributed.get_world_size()
+            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
+
         if is_fsdp_enabled():
             low_cpu_mem_usage = True

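The mesh built here is one-dimensional over every rank in the job, so each GPU ends up holding one shard of every split layer. A small sketch of what that object looks like, assuming `torch.distributed` is already initialized as the check above requires:

import torch
from torch.distributed.device_mesh import init_device_mesh

# assumes torch.distributed.init_process_group(...) has already run (see the check above)
world_size = torch.distributed.get_world_size()
mesh = init_device_mesh("cuda", (world_size,))

print(mesh.size())       # == world_size, e.g. 4 on a 4-GPU run
print(mesh.get_group())  # the process group spanning all ranks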
@@ -3600,7 +3643,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
         elif not low_cpu_mem_usage:
-            raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+            raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`")

         if low_cpu_mem_usage:
             if is_deepspeed_zero3_enabled():
@@ -3609,7 +3652,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             elif not is_accelerate_available():
                 raise ImportError(
-                    f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                    f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )

         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
@@ -4186,6 +4229,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)

+        if device_mesh is not None and not model.supports_tp_plan:
+            raise NotImplementedError("This model does not have a tensor parallel plan.")
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config

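`supports_tp_plan` is the model-level opt-in gate checked here. A hedged sketch of what such a check could look like (this is a guess at the shape of the property, not code copied from the library):

def supports_tp_plan(model):
    """True if the model declares a tensor-parallel plan at the config or module level (sketch)."""
    if getattr(model.config, "base_model_tp_plan", None) is not None:
        return True
    # a model can also opt in through per-module `_tp_plan` attributes
    return any(getattr(m, "_tp_plan", None) for m in model.modules())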
@@ -4336,6 +4382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_path=gguf_path,
                 weights_only=weights_only,
+                device_mesh=device_mesh,
             )

         # make sure token embedding weights are still tied if needed
@@ -4370,8 +4417,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
                 pass

-        # Dispatch model with hooks on all devices if necessary
-        if device_map is not None:
+        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
+        # harms performance)
+        if device_map is not None and device_mesh is None:
             device_map_kwargs = {
                 "device_map": device_map,
                 "offload_dir": offload_folder,
@@ -4398,6 +4446,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
                 dispatch_model(model, **device_map_kwargs)

+        # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+        # not part of the state_dict (persistent=False)
+        if device_mesh is not None:
+            for buffer in model.buffers():
+                if buffer.device != tp_device:
+                    buffer.data = buffer.to(tp_device)
+
         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer
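Non-persistent buffers such as a rotary embedding's `inv_freq` never appear in the checkpoint, so the loading loop above leaves them on whatever device the module was created on; this pass moves them to the rank's device. A self-contained sketch of the same fix-up on a toy module (the device choice is an assumption):

import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps it out of the state_dict, hence out of the loading loop
        self.register_buffer("inv_freq", inv_freq, persistent=False)

model = Rotary()
tp_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
for buffer in model.buffers():
    if buffer.device != tp_device:
        buffer.data = buffer.to(tp_device)  # move in place, mirroring the loop added above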
@@ -4420,16 +4475,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             }
             return model, loading_info

-        if tp_plan is not None:
-            assert tp_device is not None, "tp_device not set!"
-            if not model.supports_tp_plan:
-                raise NotImplementedError("This model does not have a tensor parallel plan.")
-            # Assuming sharding the model onto the world
-            world_size = torch.distributed.get_world_size()
-            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
-            # Apply Tensor Parallelism
-            model.tensor_parallel(device_mesh)
-
         return model

     @staticmethod
@@ -4523,6 +4568,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         keep_in_fp32_modules=None,
         gguf_path=None,
         weights_only=True,
+        device_mesh=None,
     ):
         is_safetensors = False
         is_quantized = hf_quantizer is not None
@@ -4822,6 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     is_safetensors=is_safetensors,
                     keep_in_fp32_modules=keep_in_fp32_modules,
                     unexpected_keys=unexpected_keys,
+                    device_mesh=device_mesh,
                 )
             else:
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -4911,6 +4958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         is_safetensors=is_safetensors,
                         keep_in_fp32_modules=keep_in_fp32_modules,
                         unexpected_keys=unexpected_keys,
+                        device_mesh=device_mesh,
                     )
                     error_msgs += new_error_msgs
                 else:
@@ -5188,7 +5236,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

     def tensor_parallel(self, device_mesh):
         """
-        Tensor parallelize the model across the given device mesh.
+        Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model
+        was already loaded in memory. Note however that this means that each process will first initialize the whole model,
+        then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time.
+
+        Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during initialization,
+        so that the expected per-device memory spike at loading time is not larger than the final model size on each device.

         Args:
             device_mesh (`torch.distributed.DeviceMesh`):
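Putting the docstring's recommendation together, the preferred path now looks roughly like the sketch below when launched under `torchrun`; the checkpoint name and generation arguments are placeholders, not part of this diff:

# Launch with: torchrun --nproc-per-node=<num_gpus> run_tp.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.distributed.init_process_group("nccl")           # required before passing tp_plan
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model_id = "meta-llama/Llama-3.1-8B-Instruct"          # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism shards each layer across GPUs", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

torch.distributed.destroy_process_group()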
test_tp.py

@@ -81,17 +81,13 @@ class TestTensorParallel(TestCasePlus):
         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
         torch.distributed.barrier()

-        # The expected full model memory footprint
-        expected_model_memory = 16
+        # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
+        expected_model_memory_per_device = (16 / world_size) + 1
         overhead_factor = 1.2

-        # Assert we did not use more than the full model expected memory (with some overhead)
-        if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
-            raise ValueError("Loading the model used more than the full model size")
+        # Check that we do not use more than the expected sharded size during initialization
+        if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
+            raise ValueError("Loading the model used more than the expected fraction of model size per device")

-        # Assert we correctly handled the sharding between devices
-        if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
-            raise ValueError("Each model shard is larger than what is expected.")
-
         torch.distributed.barrier()
         torch.distributed.destroy_process_group()
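For the record, the bound the rewritten test enforces works out as follows (the world size here is an example; the 16 GiB figure mirrors the test's constant for its fp16 checkpoint):

full_model_gib = 16      # fp16 footprint assumed by the test
world_size = 4           # example: 4 GPUs
overhead_factor = 1.2

# +1 GiB because some modules (e.g. the embeddings) are not split
expected_per_device = full_model_gib / world_size + 1
budget = expected_per_device * overhead_factor
print(f"per-device peak budget: {budget:.1f} GiB")  # 6.0 GiB with 4 GPUs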