Fix big model inference for T5 models in float16 (#22095)

* Fix big model inference for T5 models in float16

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Style

* Trigger CI with latest release

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
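
For context, a minimal sketch of the scenario this change addresses (the checkpoint name is illustrative, not taken from the commit): loading a T5 checkpoint in float16 with an automatic device map, where the modules the model marks as fp32-only (the feed-forward `wo` projections for T5) should stay in float32 during device-map planning.

    import torch
    from transformers import T5ForConditionalGeneration

    # Big model inference for a T5 checkpoint in float16. With this fix,
    # Accelerate's device-map planning receives `special_dtypes`, so the
    # numerically sensitive fp32-only parameters are budgeted as float32.
    model = T5ForConditionalGeneration.from_pretrained(
        "t5-large",
        torch_dtype=torch.float16,
        device_map="auto",
    )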
Author: Sylvain Gugger
Date: 2023-03-14 09:20:16 -04:00 (committed by GitHub)
Commit: b45192ec47
Parent: 7f5ad6c35b

src/transformers/modeling_utils.py

@@ -2547,6 +2547,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             ) >= version.parse("0.37.0")
         if isinstance(device_map, str):
+            special_dtypes = {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
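
The dictionary added above can be reproduced outside of `from_pretrained`; a rough sketch, assuming a T5 checkpoint whose `_keep_in_fp32_modules` is `["wo"]` (the parameter name in the comment is illustrative):

    import torch
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
    keep_in_fp32_modules = model._keep_in_fp32_modules or []  # ["wo"] for T5

    # Same comprehension as in the diff: map every parameter whose name contains
    # one of the fp32-only module names to torch.float32.
    special_dtypes = {
        name: torch.float32
        for name, _ in model.named_parameters()
        if any(m in name for m in keep_in_fp32_modules)
    }
    print(list(special_dtypes)[:1])
    # e.g. ['encoder.block.0.layer.1.DenseReluDense.wo.weight']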
@@ -2557,22 +2562,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
                 raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
+            kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory}
+            if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+                kwargs["special_dtypes"] = special_dtypes
+            elif len(special_dtypes) > 0:
+                logger.warn(
+                    "This model has some weights that should be kept in higher precision, you need to upgrade "
+                    "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+                )
             if device_map != "sequential" and get_balanced_memory is not None:
                 max_memory = get_balanced_memory(
                     model,
-                    max_memory=max_memory,
-                    no_split_module_classes=no_split_modules,
                     dtype=torch_dtype,
                     low_zero=(device_map == "balanced_low_0"),
+                    **kwargs,
                 )
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(
-                model,
-                no_split_module_classes=no_split_modules,
-                dtype=torch_dtype if not load_in_8bit else torch.int8,
-                max_memory=max_memory,
-            )
+            device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)
         if load_in_8bit:
             # The LM head / tied weights or any last module can stay on disk / CPU