From 330d8b991fee414e75098baa76b33b408d4a8b53 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Tue, 21 Mar 2023 13:16:07 +0300 Subject: [PATCH] replace_8bit_linear modules_to_not_convert default value fix (#22238) * Fixed modules_to_not_convert default value * Fixed modules_to_not_convert docstring * Update src/transformers/utils/bitsandbytes.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/utils/bitsandbytes.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * ["lm_head"] if modules_to_not_convert is None --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/bitsandbytes.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index efd9abf6ce5..aee3fec9744 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -84,7 +84,7 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None): module._parameters[tensor_name] = new_value -def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", current_key_name=None): +def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None): """ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8(): @@ -105,14 +105,15 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", threshold (`float`, *optional*, defaults to 6.0): `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to `6.0` as described by the paper. - modules_to_not_convert (`str`, *optional*, defaults to `lm_head`): - Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision for numerical stability reasons. current_key_name (`List[`str`]`, *optional*): An array to track the current key of the recursion. This is used to check whether the current key (part of it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or `disk`). """ + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert for name, module in model.named_children(): if current_key_name is None: current_key_name = []