<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Models

The base classes [`PreTrainedModel`], [`TFPreTrainedModel`], and
[`FlaxPreTrainedModel`] implement the common methods for loading/saving a model either from a local
file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS
S3 repository).

[`PreTrainedModel`] and [`TFPreTrainedModel`] also implement a few methods which
are common among all the models to:

- resize the input token embeddings when new tokens are added to the vocabulary
- prune the attention heads of the model.

The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
(for the PyTorch models) and [`~modeling_tf_utils.TFModelUtilsMixin`] (for the TensorFlow models) or,
for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch models),
[`~generation_tf_utils.TFGenerationMixin`] (for the TensorFlow models) and
[`~generation_flax_utils.FlaxGenerationMixin`] (for the Flax/JAX models).

## PreTrainedModel

[[autodoc]] PreTrainedModel
    - push_to_hub
    - all

<a id='from_pretrained-torch-dtype'></a>

### Large model loading

In Transformers 4.20.0, the [`~PreTrainedModel.from_pretrained`] method has been reworked to accommodate large models using [Accelerate](https://huggingface.co/docs/accelerate/big_modeling). This requires Accelerate >= 0.9.0 and PyTorch >= 1.9.0. Instead of creating the full model and then loading the pretrained weights inside it (which takes twice the size of the model in RAM: one copy for the randomly initialized model, one for the weights), there is an option to create the model as an empty shell and then only materialize its parameters when the pretrained weights are loaded.

This option can be activated with `low_cpu_mem_usage=True`. The model is first created on the meta device (with empty weights) and the state dict is then loaded inside it (shard by shard in the case of a sharded checkpoint). This way the maximum RAM used is the full size of the model only.

```py
from transformers import AutoModelForSeq2SeqLM

t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", low_cpu_mem_usage=True)
```

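Under the hood, this relies on Accelerate's empty-weights utilities. The following is a rough sketch of the idea, not literally what `from_pretrained` executes:

```py
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("bigscience/T0pp")
with init_empty_weights():
    # parameters are created on the meta device: no RAM is allocated for them yet;
    # the real weights are only materialized when the checkpoint is loaded
    model = AutoModelForSeq2SeqLM.from_config(config)
```
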
Moreover, you can directly place the model on different devices if it doesn't fully fit in RAM (this only works for inference for now). With `device_map="auto"`, Accelerate will determine where to put each layer to maximize the use of your fastest devices (GPUs) and offload the rest to the CPU, or even to the hard drive if you don't have enough GPU RAM (or CPU RAM). Even if the model is split across several devices, it will run as you would normally expect.

When passing a `device_map`, `low_cpu_mem_usage` is automatically set to `True`, so you don't need to specify it:

```py
from transformers import AutoModelForSeq2SeqLM

t0pp = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")
```

You can inspect how the model was split across devices by looking at its `hf_device_map` attribute:

```py
t0pp.hf_device_map
```

```python out
{'shared': 0,
 'decoder.embed_tokens': 0,
 'encoder': 0,
 'decoder.block.0': 0,
 'decoder.block.1': 1,
 'decoder.block.2': 1,
 'decoder.block.3': 1,
 'decoder.block.4': 1,
 'decoder.block.5': 1,
 'decoder.block.6': 1,
 'decoder.block.7': 1,
 'decoder.block.8': 1,
 'decoder.block.9': 1,
 'decoder.block.10': 1,
 'decoder.block.11': 1,
 'decoder.block.12': 1,
 'decoder.block.13': 1,
 'decoder.block.14': 1,
 'decoder.block.15': 1,
 'decoder.block.16': 1,
 'decoder.block.17': 1,
 'decoder.block.18': 1,
 'decoder.block.19': 1,
 'decoder.block.20': 1,
 'decoder.block.21': 1,
 'decoder.block.22': 'cpu',
 'decoder.block.23': 'cpu',
 'decoder.final_layer_norm': 'cpu',
 'decoder.dropout': 'cpu',
 'lm_head': 'cpu'}
```

You can also write your own device map following the same format (a dictionary mapping layer names to devices). It should map all parameters of the model to a given device, but you don't have to detail where all the submodules of one layer go if that layer is entirely on the same device. For instance, the following device map would work properly for T0pp (as long as you have the GPU memory):

```python
device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
```

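If you would rather start from an automatically computed map and tweak it by hand, Accelerate exposes a helper for this. Below is a minimal sketch reusing the empty-weights trick shown earlier; the memory budgets are illustrative values, not recommendations:

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("bigscience/T0pp")
with init_empty_weights():
    model = AutoModelForSeq2SeqLM.from_config(config)

# compute a candidate device map given per-device memory budgets, then edit it as needed
device_map = infer_auto_device_map(model, max_memory={0: "10GiB", 1: "10GiB", "cpu": "30GiB"})
```
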
Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.

### Model Instantiation dtype

Under PyTorch a model normally gets instantiated in `torch.float32` format. This can be an issue if one tries to
load a model whose weights are in fp16, since it'd then require twice as much memory. To overcome this limitation, you can
either explicitly pass the desired `dtype` using the `torch_dtype` argument:

```python
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
```

or, if you want the model to always load in the most optimal memory pattern, you can use the special value `"auto"`,
and then `dtype` will be automatically derived from the model's weights:

```python
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype="auto")
```

Models instantiated from scratch can also be told which `dtype` to use with:

```python
import torch
from transformers import AutoModel, T5Config

config = T5Config.from_pretrained("t5-small")
# note: passing `torch_dtype` to `from_config` is supported in recent versions of Transformers
model = AutoModel.from_config(config, torch_dtype=torch.float16)
```

Due to PyTorch design, this functionality is only available for floating dtypes.

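As a quick sanity check, you can verify which `dtype` a model actually ended up with by reading its `dtype` attribute or inspecting a parameter directly:

```python
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
print(model.dtype)  # torch.float16
print(next(model.parameters()).dtype)  # same check, on a single parameter
```
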
### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition

From the paper [`LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale`](https://arxiv.org/abs/2208.07339), we support HuggingFace 🤗 integration for all models in the Hub with a few lines of code.
The method works for models trained in half precision (that is, either `float16` or `bfloat16`) as well as in full precision, and aims to reduce the size of `nn.Linear` layers by a factor of 2 (if trained in half precision) or 4 (if trained in full precision), without affecting quality too much, by operating on the outliers in half precision.
This technique is useful and works well for billion-scale models (>1B parameters), therefore we advise you to use it only for models of that scale. This method has been tested for models from 2 billion to 176 billion parameters and supports only PyTorch models.

![HFxbitsandbytes.png]()

Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix-multiplied in fp16 (0.01% of values), (2) a regular stream of int8 matrix multiplications (99.9% of values). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).
Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).

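To make the two streams concrete, here is an illustrative sketch in plain PyTorch of the decomposition for a single linear layer. This is pedagogical pseudo-quantization only (per-tensor absmax scaling for brevity, where bitsandbytes uses vector-wise scaling and native int8 kernels); it is not the actual bitsandbytes implementation:

```python
import torch


def mixed_int8_matmul_sketch(x: torch.Tensor, weight: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    # x: (batch, in_features) activations; weight: (out_features, in_features)
    # 1) feature columns holding at least one outlier activation go to the fp16 stream
    outlier_cols = (x.abs() > threshold).any(dim=0)

    # 2) outlier stream: ordinary matmul in the original precision on the (few) outlier columns
    out_fp16 = x[:, outlier_cols] @ weight[:, outlier_cols].t()

    # 3) regular stream: absmax-quantize the remaining values to int8, multiply, dequantize
    x_reg, w_reg = x[:, ~outlier_cols], weight[:, ~outlier_cols]
    x_scale = x_reg.abs().max().clamp(min=1e-12) / 127.0
    w_scale = w_reg.abs().max().clamp(min=1e-12) / 127.0
    x_q = (x_reg / x_scale).round().to(torch.int8)
    w_q = (w_reg / w_scale).round().to(torch.int8)
    # the real kernels multiply natively in int8; we cast back to float here for simplicity
    out_int8 = (x_q.float() @ w_q.float().t()) * (x_scale * w_scale)

    # 4) recombine the two streams
    return out_fp16 + out_int8.to(out_fp16.dtype)
```

The weights of the regular stream are what gets stored in int8, which is where the memory savings come from.
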
Note also that you need a GPU to run mixed-8bit models, as the kernels have been compiled for GPUs only. Make sure that you have enough GPU RAM to store a quarter (or a half, if your model is natively in half precision) of the model before using this feature.

Below are some notes to help you use this module, or follow [this demo on Google Colab](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing).

#### Requirements

- Make sure you run this on an NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs - e.g. T4, RTX20s, RTX30s, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores.
- Install the correct version of `bitsandbytes` by running:
`pip install -i https://test.pypi.org/simple/ bitsandbytes`
- Install `accelerate`:
`pip install accelerate`

#### Running mixed-int8 models

After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows:

```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
```

The implementation supports a multi-GPU setup thanks to `accelerate` as the backend. If you want to control the amount of GPU memory to allocate to each GPU, you can use the `max_memory` argument (here, to allocate `1GB` to GPU-0 and `2GB` to GPU-1, you would use `max_memory={0: "1GB", 1: "2GB"}`):

```py
from transformers import AutoModelForCausalLM

max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```

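Once loaded, the 8-bit model can be used for inference like any other model. A minimal generation sketch, continuing from the snippet above (the prompt is just an illustration):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)  # move inputs to the first GPU
outputs = model_8bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
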
## ModuleUtilsMixin

[[autodoc]] modeling_utils.ModuleUtilsMixin

## TFPreTrainedModel

[[autodoc]] TFPreTrainedModel
    - push_to_hub
    - all

## TFModelUtilsMixin

[[autodoc]] modeling_tf_utils.TFModelUtilsMixin

## FlaxPreTrainedModel

[[autodoc]] FlaxPreTrainedModel
    - push_to_hub
    - all

## Pushing to the Hub

[[autodoc]] utils.PushToHubMixin

## Sharded checkpoints

[[autodoc]] modeling_utils.load_sharded_checkpoint