bitsandbytes - Linear8bitLt integration into transformers models (#17901)

* first commit

* correct replace function

* add final changes

- works like charm!
- cannot implement tests yet
- tested

* clean up a bit

* add bitsandbytes dependencies

* working version

- added import function
- added bitsandbytes utils file

* small fix

* small fix

- fix import issue

* fix import issues

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit

- move bitsandbytes utils to utils
- change comments on functions

* reformat docstring

- reformat docstring on init_empty_weights_8bit

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* revert bad formatting

* change to bitsandbytes

* refactor a bit

- remove init8bit since it is useless

* more refactoring

- fixed init empty weights issue
- added threshold param

* small hack to make it work

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* revmoe the small hack

* modify utils file

* make style + refactor a bit

* create correctly device map

* add correct dtype for device map creation

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

- remove with torch.grad
- do not rely on Python bool magic!

* add docstring

 - add docstring for new kwargs

* add docstring

- comment `replace_8bit_linear` function
- fix weird formatting

* - added more documentation
- added new utility function for memory footprint tracking
- colab demo to add

* few modifs

- typo doc
- force cast into float16 when load_in_8bit is enabled

* added colab link

* add test architecture + docstring a bit

* refactor a bit testing class

* make style + refactor a bit

* enhance checks

- add more checks
- start writing saving test

* clean up a bit

* male style

* add more details on doc

* add more tests

- still needs to fix 2 tests

* replace by "or"

- could not fix it from GitHub GUI

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit testing code + add readme

* make style

* fix import issue

* Update src/transformers/modeling_utils.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* add few comments

* add more doctring + make style

* more docstring

* raise error when loaded in 8bit

* make style

* add warning if loaded on CPU

* add small sanity check

* fix small comment

* add bitsandbytes on dockerfile

* Improve documentation

- improve documentation from comments

* add few comments

* slow tests pass on the VM but not on the CI VM

* Fix merge conflict

* make style

* another test should pass on a multi gpu setup

* fix bad import in testing file

* Fix slow tests

- remove dummy batches
- no more CUDA illegal memory errors

* odify dockerfile

* Update docs/source/en/main_classes/model.mdx

* Update Dockerfile

* Update model.mdx

* Update Dockerfile

* Apply suggestions from code review

* few modifications

- lm head can stay on disk/cpu
- change model name so that test pass

* change test value

- change test value to the correct output
- torch bmm changed to baddmm in bloom modeling when merging

* modify installation guidelines

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* replace `n`by `name`

* merge `load_in_8bit` and `low_cpu_mem_usage`

* first try - keep the lm head in full precision

* better check

- check the attribute `base_model_prefix` instead of computing the number of parameters

* added more tests

* Update src/transformers/utils/bitsandbytes.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit

* improve documentation

- fix typos for installation
- change title in the documentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
Younes Belkada 2022-08-10 09:13:36 +02:00 committed by GitHub
parent 8cf4a6f0a6
commit 4a51075a96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 534 additions and 8 deletions

View File

@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5
RUN python3 -m pip install --no-cache-dir decord
# When installing in editable mode, `transformers` is not recognized as a package.

View File

@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.
### Model Instantiation dtype
@ -133,6 +133,45 @@ model = AutoModel.from_config(config)
Due to Pytorch design, this functionality is only available for floating dtypes.
### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition
From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we suport HuggingFace 🤗 integration for all models in the Hub with few lines of code.
For models trained in half-precision (aka, either `float16` or `bfloat16`) or full precision. This method aims to reduce `nn.Linear` size by 2 (if trained in half precision) or by 4 if trained in full precision, without affecting too much quality by operating on the outliers in half-precision.
This technique is useful and works well for billion scale models (>1B parameters) therefore we advice you to use it only for models of that scale. This method has been tested for 2-billion to 176-billion scale models and supports only PyTorch models.
![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png)
Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).
Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).
Note also that you would require a GPU to run mixed-8bit models as the kernels has been compiled for GPUs only. Make sure that you have enough GPU RAM to store the quarter (or half if your model is natively in half precision) of the model before using this feature.
Below are some notes to help you use this module, or follow this demo on Google colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing)
#### Requirements
- Make sure you run that on a NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs - e.g. T4, RTX20s RTX30s, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores.
- Install the correct version of `bitsandbytes` by running:
`pip install -i https://test.pypi.org/simple/ bitsandbytes`
- Install `accelerate`:
`pip install accelerate`
#### Running mixed-int8 models
After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows:
```py
model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```
The implementation supports multi-GPU setup thanks to `accelerate` as backend. If you want to control the GPU memory you want to allocate for each GPU, you can use the `max_memory` argument as follows:
(If allocating `1GB` into GPU-0 and `2GB` into GPU-1, you can use `max_memory={0:"1GB", 1:"2GB"}`)
```py
max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```
## ModuleUtilsMixin

View File

@ -462,6 +462,7 @@ _import_structure = {
"is_vision_available",
"logging",
],
"utils.bitsandbytes": [],
}
# sentencepiece-backed objects

View File

@ -61,6 +61,7 @@ from .utils import (
copy_func,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
logging,
replace_return_docstrings,
@ -83,6 +84,9 @@ if is_accelerate_available():
else:
get_balanced_memory = None
if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
logger = logging.get_logger(__name__)
@ -501,6 +505,7 @@ def _load_state_dict_into_meta_model(
state_dict_folder=None,
state_dict_index=None,
dtype=None,
load_in_8bit=False,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@ -561,13 +566,14 @@ def _load_state_dict_into_meta_model(
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
else:
elif not load_in_8bit:
set_module_tensor_to_device(model, param_name, param_device, value=param)
else:
set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
return error_msgs, offload_index, state_dict_index
@ -1578,6 +1584,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
)
def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""
@ -1707,6 +1731,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
load_in_8bit (`bool`, *optional*, defaults to `False`):
If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please
install `bitsandbytes` compiled with your CUDA version by running `pip install -i
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
not compiled and adapted for CPUs.
int8_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
values will be done in fp16. Values are usually normally distributed, that is, most values are in the
range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently
distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
@ -1796,7 +1836,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
@ -1804,7 +1846,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
" ignored."
)
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
@ -1824,6 +1865,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" pip install bitsandbytes` "
)
if torch_dtype == "auto" or torch_dtype != torch.float16:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
torch_dtype = torch.float16
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
if device_map is None:
raise ValueError(
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
"`.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
"Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)
from_pt = not (from_tf | from_flax)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@ -2063,12 +2126,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif low_cpu_mem_usage:
elif load_in_8bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
if load_in_8bit:
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_key_to_not_convert(model)
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
if isinstance(device_map, str):
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
@ -2091,9 +2161,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(
model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
model,
no_split_module_classes=no_split_modules,
dtype=torch_dtype if not load_in_8bit else torch.int8,
max_memory=max_memory,
)
if load_in_8bit:
# The LM head can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
del device_map_without_lm_head
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@ -2145,6 +2227,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
)
# make sure token embedding weights are still tied if needed
@ -2185,6 +2268,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_folder=None,
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
):
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@ -2250,7 +2334,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
if param.device == torch.device("meta"):
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@ -2359,6 +2446,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
load_in_8bit=load_in_8bit,
)
error_msgs += new_error_msgs
else:

View File

@ -0,0 +1,142 @@
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
if is_bitsandbytes_available():
import torch
import torch.nn as nn
import bitsandbytes as bnb
if is_accelerate_available():
from accelerate import init_empty_weights
def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
"""
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
class `Int8Params` from `bitsandbytes`.
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
if is_buffer:
has_fp16_weights = None
else:
has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None)
if has_fp16_weights is not None:
param = module._parameters[tensor_name]
if param.device.type != "cuda":
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
if value.dtype == torch.int8:
raise ValueError(
"You cannot load weights that are saved in int8 using `load_in_8bit=True`, make sure you are",
" using `load_in_8bit=True` on float32/float16/bfloat16 weights.",
)
else:
new_value = torch.tensor(value, device="cpu")
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
module._parameters[tensor_name] = new_value
else:
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)
if is_buffer:
module._buffers[tensor_name] = new_value
else:
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert="lm_head"):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
bitsandbytes`
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
(0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
predictive degradation is possible for very large models (>=176B parameters).
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
threshold (`float`, *optional*, defaults to 6.0):
`int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
`6.0` as described by the paper.
modules_to_not_convert (`str`, *optional*, defaults to `lm_head`):
Name of the module to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
"""
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold, modules_to_not_convert)
if isinstance(module, nn.Linear) and name != modules_to_not_convert:
with init_empty_weights():
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=threshold,
)
return model
def get_key_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Ignore this for base models (BertModel, GPT2Model, etc.)
if not hasattr(model, model.base_model_prefix):
return ""
# otherwise they have an attached head
list_modules = list(model.named_parameters())
last_name = list_modules[-1][0]
return last_name.split(".")[0]

View File

@ -0,0 +1,37 @@
# Testing mixed int8 quantization
## Hardware requirements
I am using a setup of 2 GPUs that are NVIDIA-Tesla T4 15GB
## Virutal envs
```conda create --name int8-testing python==3.8```
```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit```
```pip install -e ".[dev]"```
```pip install -i https://test.pypi.org/simple/ bitsandbytes```
```pip install git+https://github.com/huggingface/accelerate.git@e0212893ea6098cc0a7a3c7a6eb286a9104214c1```
## Trobleshooting
```conda create --name int8-testing python==3.8```
```pip install -i https://test.pypi.org/simple/ bitsandbytes```
```conda install pytorch torchvision torchaudio -c pytorch```
```git clone https://github.com/younesbelkada/transformers.git && git checkout integration-8bit```
```pip install -e ".[dev]"```
```pip install git+https://github.com/huggingface/accelerate.git@b52b793ea8bac108ba61192eead3cf11ca02433d```
### Check driver settings:
```
nvcc --version
```
```
ls -l $CONDA_PREFIX/lib/libcudart.so
```
### Recurrent bugs
Sometimes you have to run a "dummy" inference pass when dealing with a multi-GPU setup. Checkout the ```test_multi_gpu_loading``` and the ```test_pipeline``` functions.

View File

View File

@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
require_bitsandbytes,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
if is_torch_available():
import torch
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class BaseMixedInt8Test(unittest.TestCase):
# We keep the constants inside the init function and model loading inside setUp function
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
# Therefore here we use only bloom-1b3 to test our module
model_name = "bigscience/bloom-1b7"
# Constant values
EXPECTED_RELATIVE_DIFFERENCE = (
1.540025 # This was obtained on a Quadro RTX 8000 so the number might slightly change
)
input_text = "Hello my name is"
EXPECTED_OUTPUT = "Hello my name is John.\nI am a friend of the family.\n"
MAX_NEW_TOKENS = 10
def setUp(self):
# Models and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
class MixedInt8Test(BaseMixedInt8Test):
def setUp(self):
super().setUp()
# Models and tokenizer
self.model_fp16 = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto")
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
def tearDown(self):
r"""
TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
"""
del self.model_fp16
del self.model_8bit
gc.collect()
torch.cuda.empty_cache()
def test_memory_footprint(self):
r"""
A simple test to check if the model conversion has been done correctly by checking on the
memory footprint of the converted model and the class type of the linear layers of the converted models
"""
from bitsandbytes.nn import Int8Params
mem_fp16 = self.model_fp16.get_memory_footprint()
mem_8bit = self.model_8bit.get_memory_footprint()
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers + the testing model is relatively small, we might not get
the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
"""
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = self.model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
super().setUp()
# model_name
self.model_name = "bigscience/bloom-560m"
# Models and tokenizer
self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
self.model_name, load_in_8bit=True, device_map="auto"
)
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
def tearDown(self):
r"""
TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
"""
del self.base_model
del self.sequence_model
del self.model_8bit
gc.collect()
torch.cuda.empty_cache()
def test_correct_head_class(self):
r"""
A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
are kept in their native class.
"""
from bitsandbytes.nn import Int8Params
# last param of a base model should be a linear8bit module
self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
# Other heads should be nn.Parameter
self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
class MixedInt8TestPipeline(BaseMixedInt8Test):
def setUp(self):
super().setUp()
def tearDown(self):
r"""
TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
"""
del self.pipe
gc.collect()
torch.cuda.empty_cache()
def test_pipeline(self):
r"""
The aim of this test is to verify that the mixed int8 is compatible with `pipeline` from transformers. Since
we used pipline for inference speed benchmarking we want to make sure that this feature does not break anything
on pipline.
"""
# self._clear_cuda_cache()
self.pipe = pipeline(
"text-generation",
model=self.model_name,
model_kwargs={"device_map": "auto", "load_in_8bit": True},
max_new_tokens=self.MAX_NEW_TOKENS,
)
# Real second forward pass
pipeline_output = self.pipe(self.input_text)
self.assertEqual(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUT)
@require_torch_multi_gpu
class MixedInt8TestMultiGpu(BaseMixedInt8Test):
def setUp(self):
super().setUp()
def test_multi_gpu_loading(self):
r"""
This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
"""
memory_mapping = {0: "1GB", 1: "2GB"}
model_parallel = AutoModelForCausalLM.from_pretrained(
self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
)
def get_list_devices(model):
list_devices = []
for _, module in model.named_children():
if len(list(module.children())) > 0:
list_devices.extend(get_list_devices(module))
else:
# Do a try except since we can encounter Dropout modules that does not
# have any device set
try:
list_devices.append(next(module.parameters()).device.index)
except BaseException:
continue
return list_devices
list_devices = get_list_devices(model_parallel)
# Check that we have dispatched the model into 2 separate devices
self.assertTrue((1 in list_devices) and (0 in list_devices))
# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
# Second real batch
output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

View File

@ -465,6 +465,7 @@ EXPECTED_TEST_FILES_NEVER_TOUCHED = [
"tests/sagemaker/test_single_node_gpu.py", # SageMaker test
"tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test
"tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test
"tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test
]