Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
replace_8bit_linear modules_to_not_convert default value fix (#22238)
* Fixed modules_to_not_convert default value
* Fixed modules_to_not_convert docstring
* Update src/transformers/utils/bitsandbytes.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/utils/bitsandbytes.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* ["lm_head"] if modules_to_not_convert is None
---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent c07a02a4b7
commit 330d8b991f
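The substance of the change is the usual `None`-sentinel idiom for default arguments: instead of the bare string default `"lm_head"`, the parameter now defaults to `None` and is resolved to `["lm_head"]` inside the function, so the signature matches the documented list type. A minimal sketch of the idiom (the helper name below is hypothetical, not part of the commit):

```python
from typing import List, Optional

def resolve_skip_modules(modules_to_not_convert: Optional[List[str]] = None) -> List[str]:
    # Resolve the sentinel at call time: keep `lm_head` in full precision by default,
    # but let callers pass any explicit list of module names to skip.
    return ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

assert resolve_skip_modules() == ["lm_head"]
assert resolve_skip_modules(["lm_head", "embed_out"]) == ["lm_head", "embed_out"]
```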
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -84,7 +84,7 @@ def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
         module._parameters[tensor_name] = new_value
 
 
-def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head", current_key_name=None):
+def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
     """
     A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
     library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
@@ -105,14 +105,15 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head",
         threshold (`float`, *optional*, defaults to 6.0):
             `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
             `6.0` as described by the paper.
-        modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
-            Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
+        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
+            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
             it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
             `disk`).
     """
+    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
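For reference, a hedged usage sketch of the updated signature (a toy module; assumes `bitsandbytes` and `accelerate` are installed, since the conversion relies on them; in practice the supported entry point is `from_pretrained(..., load_in_8bit=True)` rather than calling this helper directly):

```python
import torch.nn as nn
from transformers.utils.bitsandbytes import replace_8bit_linear

# Toy model standing in for a real checkpoint (names are illustrative).
class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.lm_head = nn.Linear(16, 100)

# New default: modules_to_not_convert=None falls back to ["lm_head"], so `proj`
# is swapped for a bitsandbytes Linear8bitLt while `lm_head` stays full precision.
model = replace_8bit_linear(TinyLM(), threshold=6.0)

# An explicit list still works, e.g. to also keep `proj` unconverted.
model = replace_8bit_linear(TinyLM(), modules_to_not_convert=["lm_head", "proj"])
```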