transformers/docs/source/en/perf_infer_gpu_one.md
tomeras91 3f20877da9
Add jamba (#29943)
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
2024-04-18 11:04:02 +02:00


GPU inference

GPUs are the standard choice of hardware for machine learning, unlike CPUs, because they are optimized for memory bandwidth and parallelism. To keep up with the larger sizes of modern models or to run these large models on existing and older hardware, there are several optimizations you can use to speed up GPU inference. In this guide, you'll learn how to use FlashAttention-2 (a more memory-efficient attention mechanism), BetterTransformer (a PyTorch native fastpath execution), and bitsandbytes to quantize your model to a lower precision. Finally, learn how to use 🤗 Optimum to accelerate inference with ONNX Runtime on Nvidia and AMD GPUs.

The majority of the optimizations described here also apply to multi-GPU setups!

FlashAttention-2

FlashAttention-2 is experimental and may change considerably in future versions.

FlashAttention-2 is a faster and more efficient implementation of the standard attention mechanism that can significantly speed up inference by:

  1. additionally parallelizing the attention computation over sequence length
  2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 is currently supported for the following architectures:

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed.

pip install flash-attn --no-build-isolation

We strongly suggest referring to the detailed installation instructions to learn more about supported hardware and data types!

FlashAttention-2 is also supported on AMD GPUs, but current support is limited to Instinct MI210 and Instinct MI250. We strongly suggest using this Dockerfile to use FlashAttention-2 on AMD GPUs.

To enable FlashAttention-2, pass the argument attn_implementation="flash_attention_2" to [~AutoModelForCausalLM.from_pretrained]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

FlashAttention-2 can only be used when the model's dtype is fp16 or bf16. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.
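You can check the dtype requirement before loading the model. The following is a minimal sketch, assuming a hypothetical helper named `pick_attn_implementation` (it is not part of Transformers) that falls back to `"eager"` when the `flash_attn` package is missing or the dtype is not fp16/bf16. The dtypes are written as strings so the sketch runs without PyTorch installed:

```python
import importlib.util
from typing import Optional

# dtypes accepted by FlashAttention-2, as stated above
FLASH_DTYPES = {"float16", "bfloat16"}


def pick_attn_implementation(dtype_name: str, flash_installed: Optional[bool] = None) -> str:
    """Hypothetical helper: choose a value for the `attn_implementation` argument."""
    if flash_installed is None:
        # find_spec returns None when the package cannot be found
        flash_installed = importlib.util.find_spec("flash_attn") is not None
    if flash_installed and dtype_name in FLASH_DTYPES:
        return "flash_attention_2"
    # "eager" is the plain PyTorch attention path
    return "eager"


choice = pick_attn_implementation("bfloat16")
```

You could then pass `choice` as `attn_implementation` to `from_pretrained`.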


You can also set use_flash_attention_2=True to enable FlashAttention-2 but it is deprecated in favor of attn_implementation="flash_attention_2".

FlashAttention-2 can be combined with other optimization techniques like quantization to further speed up inference. For example, you can combine FlashAttention-2 with 8-bit or 4-bit quantization:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load in 8bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    attn_implementation="flash_attention_2",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)

Expected speedups

You can benefit from considerable speedups for inference, especially for inputs with long sequences. However, since FlashAttention-2 does not support computing attention scores with padding tokens, you must manually pad/unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generations with padding tokens.
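To make the pad/unpad step concrete, here is a pure-Python sketch of the idea. It is an illustration only, not the Transformers or flash-attn implementation, and `unpad` and `repad` are hypothetical names. Padding tokens are dropped, the real tokens are flattened into one sequence with cumulative sequence boundaries, and the per-token results are scattered back into the padded layout afterwards. This extra bookkeeping is where the overhead comes from:

```python
def unpad(input_ids, attention_mask):
    """Flatten a padded batch, keeping only the tokens whose mask is 1.

    Returns the flat token list, the cumulative sequence lengths, and the
    (row, column) position of every kept token so the batch can be rebuilt.
    """
    flat, positions, cu_seqlens = [], [], [0]
    for row, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
        for col, (token, keep) in enumerate(zip(ids, mask)):
            if keep:
                flat.append(token)
                positions.append((row, col))
        cu_seqlens.append(len(flat))
    return flat, cu_seqlens, positions


def repad(flat, positions, batch_size, seq_len, pad_value=0):
    """Scatter flat per-token values back into a padded [batch, seq_len] grid."""
    out = [[pad_value] * seq_len for _ in range(batch_size)]
    for value, (row, col) in zip(flat, positions):
        out[row][col] = value
    return out


input_ids = [[5, 6, 7, 0], [8, 9, 0, 0]]
attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

flat, cu_seqlens, positions = unpad(input_ids, attention_mask)
restored = repad(flat, positions, batch_size=2, seq_len=4)
```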

To overcome this, you should use FlashAttention-2 without padding tokens in the sequence during training (by packing a dataset or concatenating sequences until reaching the maximum sequence length).
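Packing can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions: `pack_sequences` is a hypothetical helper, each example is followed by an EOS token, and the incomplete final block is dropped. Real data pipelines may handle the remainder differently:

```python
def pack_sequences(sequences, max_length, eos_token_id):
    """Concatenate tokenized examples (each followed by EOS) and cut full blocks."""
    stream = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(eos_token_id)
    # drop the trailing remainder so that every block is completely full
    n_blocks = len(stream) // max_length
    return [stream[i * max_length:(i + 1) * max_length] for i in range(n_blocks)]


blocks = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_length=4, eos_token_id=0)
# blocks == [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Every block has exactly `max_length` tokens, so no padding tokens are needed.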

For a single forward pass on tiiuae/falcon-7b with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:

For a single forward pass on meta-llama/Llama-7b-hf with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:

For sequences with padding tokens (generating with padding tokens), you need to unpad/pad the input sequences to correctly compute the attention scores. With a relatively small sequence length, a single forward pass creates overhead leading to a small speedup (in the example below, 30% of the input is filled with padding tokens):

But for larger sequence lengths, you can expect even more speedup benefits:

FlashAttention is more memory efficient, meaning you can train on much larger sequence lengths without running into out-of-memory issues. You can potentially reduce memory usage up to 20x for larger sequence lengths. Take a look at the flash-attention repository for more details.
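A back-of-the-envelope calculation shows why memory grows so quickly with sequence length. Standard attention materializes a `seq_len x seq_len` score matrix for every head, whereas FlashAttention computes attention in tiles and never stores that full matrix. The sketch below only estimates the size of the score matrix for standard attention; the head count (32) and fp16 storage (2 bytes) are assumptions chosen for illustration:

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_element=2):
    """Size of the full attention score matrix for one layer, in bytes."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for seq_len in (1024, 4096, 16384):
    gib = attention_scores_bytes(batch_size=1, num_heads=32, seq_len=seq_len) / 2**30
    print(f"seq_len={seq_len:>5}: {gib:.4f} GiB per layer")
```

Quadrupling the sequence length multiplies this term by 16, which is why long sequences run out of memory first with standard attention.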

PyTorch scaled dot product attention

PyTorch's torch.nn.functional.scaled_dot_product_attention (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for torch>=2.1.1 when an implementation is available.
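As a reference for what SDPA computes, here is a dependency-free sketch of single-head attention, softmax(QK^T / sqrt(d)) V, without masking or dropout. It only illustrates the math and is not the PyTorch implementation; PyTorch's fused kernels compute the same quantity far more efficiently:

```python
import math


def scaled_dot_product_attention(query, key, value):
    """Single-head attention over lists of row vectors (no mask, no dropout)."""
    d = len(query[0])
    output = []
    for q in query:
        # scaled dot product between this query and every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in key]
        # softmax, subtracting the max for numerical stability
        max_score = max(scores)
        weights = [math.exp(s - max_score) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # weighted sum of the value vectors
        output.append(
            [sum(w * v[j] for w, v in zip(weights, value)) for j in range(len(value[0]))]
        )
    return output


out = scaled_dot_product_attention(
    query=[[1.0, 0.0]],
    key=[[1.0, 0.0], [1.0, 0.0]],
    value=[[2.0, 0.0], [4.0, 2.0]],
)
# both keys score equally, so the output is the mean of the two value vectors
```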

For now, Transformers supports SDPA inference and training for the following architectures:

FlashAttention can only be used for models with the fp16 or bf16 torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle fp32 models.

By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with torch.backends.cuda.sdp_kernel as a context manager:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If you see a bug with the traceback below, try using the nightly version of PyTorch which may have broader coverage for FlashAttention:

RuntimeError: No available kernel. Aborting execution.

# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118

BetterTransformer

Some BetterTransformer features are being upstreamed to Transformers with default support for native torch.nn.functional.scaled_dot_product_attention. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.

Check out our benchmarks with BetterTransformer and scaled dot product attention in the Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0 blog post, and learn more about the fastpath execution in the BetterTransformer blog post.

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

  1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
  2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient scaled dot product attention (SDPA), and it calls optimized kernels like FlashAttention under the hood.

Before you start, make sure you have 🤗 Optimum installed.

Then you can enable BetterTransformer with the [PreTrainedModel.to_bettertransformer] method:

model = model.to_bettertransformer()

You can return the original Transformers model with the [~PreTrainedModel.reverse_bettertransformer] method. You should use this before saving your model to use the canonical Transformers modeling:

model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")

bitsandbytes

bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
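A rough estimate of the savings follows directly from the number of bits stored per parameter. The sketch below counts weights only, assuming every parameter is stored at the given precision; it ignores quantization constants, layers kept in higher precision, activations, and the key-value cache, so real memory usage will be higher:

```python
def weight_memory_gb(num_parameters, bits_per_parameter):
    """Approximate memory needed for the weights alone, in gigabytes (1e9 bytes)."""
    return num_parameters * bits_per_parameter / 8 / 1e9


# a 3 billion parameter model at different precisions
for name, bits in [("fp32", 32), ("fp16/bf16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{name:>9}: {weight_memory_gb(3e9, bits):.1f} GB")
```

For 3 billion parameters this gives about 12 GB in fp32, 6 GB in fp16/bf16, 3 GB in 8-bit, and 1.5 GB in 4-bit.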

Make sure you have bitsandbytes and 🤗 Accelerate installed:

# these versions support 8-bit and 4-bit
pip install "bitsandbytes>=0.39.0" "accelerate>=0.20.0"

# install Transformers
pip install transformers

4-bit

To load a model in 4-bit for inference, use the load_in_4bit parameter. The device_map parameter is optional, but we recommend setting it to "auto" to allow 🤗 Accelerate to automatically and efficiently allocate the model given the available resources in the environment.

from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)

To load a model in 4-bit for inference with multiple GPUs, you can control how much GPU RAM you want to allocate to each GPU. For example, to distribute 600MB of memory to the first GPU and 1GB of memory to the second GPU:

max_memory_mapping = {0: "600MB", 1: "1GB"}
model_name = "bigscience/bloom-3b"
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
)

8-bit

If you're curious and interested in learning more about the concepts underlying 8-bit quantization, read the Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes blog post.

To load a model in 8-bit for inference, use the load_in_8bit parameter. The device_map parameter is optional, but we recommend setting it to "auto" to allow 🤗 Accelerate to automatically and efficiently allocate the model given the available resources in the environment:

from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

If you're loading a model in 8-bit for text generation, you should use the [~transformers.GenerationMixin.generate] method instead of the [Pipeline] function which is not optimized for 8-bit models and will be slower. Some sampling strategies, like nucleus sampling, are also not supported by the [Pipeline] for 8-bit models. You should also place all inputs on the same device as the model:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-2b5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

prompt = "Hello, my llama is cute"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
generated_ids = model_8bit.generate(**inputs)
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

To load a model in 8-bit for inference with multiple GPUs, you can control how much GPU RAM you want to allocate to each GPU. For example, to distribute 1GB of memory to the first GPU and 2GB of memory to the second GPU:

max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
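If you have more than a couple of GPUs, you may prefer to build the mapping programmatically. This is a small sketch with a hypothetical helper, `build_max_memory`, which is not part of Transformers or Accelerate. It only produces a dictionary in the same shape as the ones above; the optional `"cpu"` entry assumes you also want to cap CPU offloading:

```python
def build_max_memory(per_gpu_gb, cpu_gb=None):
    """Hypothetical helper: map GPU index -> memory budget string."""
    mapping = {index: f"{gb}GB" for index, gb in enumerate(per_gpu_gb)}
    if cpu_gb is not None:
        # optional budget for weights offloaded to CPU RAM
        mapping["cpu"] = f"{cpu_gb}GB"
    return mapping


max_memory_mapping = build_max_memory([1, 2])
# max_memory_mapping == {0: "1GB", 1: "2GB"}
```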

Feel free to try running an 11 billion parameter T5 model or the 3 billion parameter BLOOM model for inference on Google Colab's free tier GPUs!

🤗 Optimum

Learn more details about using ORT with 🤗 Optimum in the Accelerated inference on NVIDIA GPUs and Accelerated inference on AMD GPUs guides. This section only provides a brief and simple example.

ONNX Runtime (ORT) is a model accelerator that supports accelerated inference on Nvidia GPUs and on AMD GPUs that use the ROCm stack. ORT uses optimization techniques like fusing common operations into a single node and constant folding to reduce the number of computations performed and speed up inference. ORT also places the most computationally intensive operations on the GPU and the rest on the CPU to intelligently distribute the workload between the two devices.
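To give a feel for constant folding, here is a toy sketch on a tiny expression graph. It is not how ORT is implemented; the tuple-based graph format and the `fold_constants` function are assumptions made purely for illustration. Any operation whose inputs are all known ahead of time is computed once and replaced by its result, so it no longer runs at inference time:

```python
def fold_constants(node):
    """Replace operator nodes whose inputs are all constants with their value.

    A node is a number (constant), a string (runtime input), or a tuple
    (op, left, right) where op is "add" or "mul".
    """
    if not isinstance(node, tuple):
        return node
    op, left, right = node
    left, right = fold_constants(left), fold_constants(right)
    if isinstance(left, (int, float)) and isinstance(right, (int, float)):
        return left + right if op == "add" else left * right
    return (op, left, right)


# x * (2 + 3): the addition does not depend on the runtime input x
graph = ("mul", "x", ("add", 2, 3))
folded = fold_constants(graph)
# folded == ("mul", "x", 5)
```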

ORT is supported by 🤗 Optimum which can be used in 🤗 Transformers. You'll need to use an [~optimum.onnxruntime.ORTModel] for the task you're solving, and specify the provider parameter which can be set to either CUDAExecutionProvider, ROCMExecutionProvider or TensorrtExecutionProvider. If you want to load a model that was not yet exported to ONNX, you can set export=True to convert your model on-the-fly to the ONNX format:

from optimum.onnxruntime import ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
  "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
  export=True,
  provider="CUDAExecutionProvider",
)

Now you're free to use the model for inference:

from optimum.pipelines import pipeline
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-finetuned-sst-2-english")

# use a different name so the imported `pipeline` function is not shadowed
pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
result = pipe("Both the music and visual were astounding, not to mention the actors performance.")

Combine optimizations

It is often possible to combine several of the optimization techniques described above to get the best inference performance possible for your model. For example, you can load a model in 4-bit, and then enable BetterTransformer with FlashAttention:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# load model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=quantization_config)

# enable BetterTransformer
model = model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# enable FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))