replace_8bit_linear modules_to_not_convert default value fix (#22238)

* Fixed modules_to_not_convert default value

* Fixed modules_to_not_convert docstring

* Update src/transformers/utils/bitsandbytes.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/utils/bitsandbytes.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* ["lm_head"] if modules_to_not_convert is None

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Andrei Panferov 2023-03-21 13:16:07 +03:00, committed by GitHub
parent c07a02a4b7
commit 330d8b991f


@@ -84,7 +84,7 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
        module._parameters[tensor_name] = new_value


-def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", current_key_name=None):
+def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
@@ -105,14 +105,15 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head",
        threshold (`float`, *optional*, defaults to 6.0):
            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
            `6.0` as described by the paper.
-        modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
-            Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
+        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
+            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`List[`str`]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
            `disk`).
    """
+    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
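
For reference, a minimal runnable sketch of the None-default pattern this commit adopts; the helper name `keep_in_full_precision` is hypothetical and not part of transformers:

    from typing import List, Optional

    # Illustrative only: defaulting to None and normalizing inside the function
    # keeps the argument a plain list of module names, matching the new docstring.
    def keep_in_full_precision(modules_to_not_convert: Optional[List[str]] = None) -> List[str]:
        return ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    print(keep_in_full_precision())                      # ['lm_head']
    print(keep_in_full_precision(["lm_head", "score"]))  # ['lm_head', 'score']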