Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-24 23:08:57 +06:00
Fix big model inference for T5 models in float16 (#22095)
* Fix big model inference for T5 models in float16

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Style

* Trigger CI with latest release

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
commit b45192ec47
parent 7f5ad6c35b
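For context, the scenario this commit fixes looks roughly like the sketch below: loading a T5 checkpoint in float16 with an automatic device map. T5 lists its feed-forward `wo` projections in `_keep_in_fp32_modules`, so those weights stay in float32 even when the rest of the model is half precision, and the device map has to budget for them. This is a minimal sketch, not taken from the diff, and the checkpoint name is only illustrative.

# Minimal sketch of the use case this commit fixes; assumes transformers and
# accelerate are installed. The checkpoint name is illustrative.
import torch
from transformers import T5ForConditionalGeneration

# T5 keeps some modules (its feed-forward `wo` projections) in float32 even
# when loaded in float16; the auto device map must account for those larger
# float32 weights when splitting the model across devices.
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small",
    torch_dtype=torch.float16,
    device_map="auto",
)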
src/transformers/modeling_utils.py
@@ -2547,6 +2547,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             ) >= version.parse("0.37.0")
 
         if isinstance(device_map, str):
+            special_dtypes = {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
@@ -2557,22 +2562,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 )
             elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
                 raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
+
+            kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory}
+            if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+                kwargs["special_dtypes"] = special_dtypes
+            elif len(special_dtypes) > 0:
+                logger.warn(
+                    "This model has some weights that should be kept in higher precision, you need to upgrade "
+                    "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+                )
             if device_map != "sequential" and get_balanced_memory is not None:
                 max_memory = get_balanced_memory(
                     model,
-                    max_memory=max_memory,
-                    no_split_module_classes=no_split_modules,
                     dtype=torch_dtype,
                     low_zero=(device_map == "balanced_low_0"),
+                    **kwargs,
                 )
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(
-                model,
-                no_split_module_classes=no_split_modules,
-                dtype=torch_dtype if not load_in_8bit else torch.int8,
-                max_memory=max_memory,
-            )
+            device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)
 
         if load_in_8bit:
             # The LM head / tied weights or any last module can stay on disk / CPU
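The notable part of the patch is the backward-compatibility check: `special_dtypes` is only forwarded when the installed Accelerate's `infer_auto_device_map` actually accepts it. Below is a standalone sketch of that pattern, assuming only that `infer_auto_device_map` is importable from `accelerate`; the `build_device_map_kwargs` helper and its arguments are hypothetical, written to mirror the diff.

# Standalone sketch of the feature-detection pattern used in the diff:
# forward `special_dtypes` only if the installed accelerate accepts it.
# `build_device_map_kwargs` is a hypothetical helper, not transformers API.
import inspect

import torch
from accelerate import infer_auto_device_map


def build_device_map_kwargs(model, keep_in_fp32_modules, no_split_modules, max_memory):
    # Parameters that must stay in float32 (e.g. T5's feed-forward `wo` layers).
    special_dtypes = {
        name: torch.float32
        for name, _ in model.named_parameters()
        if any(m in name for m in keep_in_fp32_modules)
    }
    kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory}
    # Older accelerate releases don't accept `special_dtypes`, so probe the
    # signature instead of pinning a minimum version.
    if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
        kwargs["special_dtypes"] = special_dtypes
    return kwargs

Probing the signature rather than comparing version strings keeps the same code path working across both old and new Accelerate releases, which is why the diff only warns, instead of failing, when the feature is missing.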