[docs] Increase visibility of torch_dtype="auto" (#35067)

* auto-dtype

* feedback
Steven Liu 2024-12-04 09:18:44 -08:00 committed by GitHub
parent baa3b22137
commit 1ed1de2fec
11 changed files with 49 additions and 35 deletions


@@ -138,12 +138,15 @@ Load a processor with [`AutoProcessor.from_pretrained`]:
<frameworkcontent>
<pt>
The `AutoModelFor` classes let you load a pretrained model for a given task (see [here](model_doc/auto) for a complete list of available tasks). For example, load a model for sequence classification with [`AutoModelForSequenceClassification.from_pretrained`]:
The `AutoModelFor` classes let you load a pretrained model for a given task (see [here](model_doc/auto) for a complete list of available tasks). For example, load a model for sequence classification with [`AutoModelForSequenceClassification.from_pretrained`].
> [!WARNING]
> By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased")
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype="auto")
```
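To see which data type was actually picked up, you can inspect the model after loading; a quick check (not part of this diff, and the exact result depends on what the checkpoint's `config.json` declares):

```py
>>> model.dtype  # e.g. torch.float32 here; "auto" falls back to the checkpoint's stored dtype when config.json doesn't set torch_dtype
```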
Easily reuse the same checkpoint to load an architecture for a different task:
@@ -151,7 +154,7 @@ Easily reuse the same checkpoint to load an architecture for a different task:
```py
>>> from transformers import AutoModelForTokenClassification
>>> model = AutoModelForTokenClassification.from_pretrained("distilbert/distilbert-base-uncased")
>>> model = AutoModelForTokenClassification.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype="auto")
```
<Tip warning={true}>


@@ -57,7 +57,7 @@ import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype="auto", device_map="auto")
model.generation_config.cache_implementation = "static"
@@ -89,7 +89,7 @@ import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype="auto", device_map="auto")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
@@ -202,7 +202,7 @@ import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To prevent long warnings :)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype="auto", device_map="auto")
model.generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)
input_text = "The theory of special relativity states "
@@ -249,7 +249,7 @@ device, _, _ = get_backend() # automatically detects the underlying device type
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("Einstein's theory of relativity states", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype="auto").to(device)
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -271,7 +271,7 @@ device, _, _ = get_backend() # automatically detects the underlying device type
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("Einstein's theory of relativity states", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype="auto").to(device)
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.7)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -300,7 +300,7 @@ device, _, _ = get_backend() # automatically detects the underlying device type
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("The second law of thermodynamics states", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype="auto").to(device)
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)
outputs = model.generate(**inputs, prompt_lookup_num_tokens=3)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -322,7 +322,7 @@ device, _, _ = get_backend() # automatically detects the underlying device type
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("The second law of thermodynamics states", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype="auto").to(device)
outputs = model.generate(**inputs, prompt_lookup_num_tokens=3, do_sample=True, temperature=0.7)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
["The second law of thermodynamics states that energy cannot be created nor destroyed. It's not a"]


@@ -41,7 +41,7 @@ Enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] metho
```py
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype="auto")
```
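The hunk above only shows the loading line; the conversion step the surrounding text refers to is [`PreTrainedModel.to_bettertransformer`]. A minimal sketch of the full sequence, reusing the same checkpoint (the conversion call itself is untouched by this commit):

```py
from transformers import AutoModelForCausalLM

# load in the checkpoint's native dtype, then swap in the BetterTransformer fastpath
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype="auto")
model = model.to_bettertransformer()
```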
## TorchScript


@@ -405,7 +405,7 @@ To load a model in 4-bit for inference, use the `load_in_4bit` parameter. The `d
from transformers import AutoModelForCausalLM
model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", load_in_4bit=True)
```
To load a model in 4-bit for inference with multiple GPUs, you can control how much GPU RAM you want to allocate to each GPU. For example, to distribute 600MB of memory to the first GPU and 1GB of memory to the second GPU:
@@ -414,7 +414,7 @@ To load a model in 4-bit for inference with multiple GPUs, you can control how m
max_memory_mapping = {0: "600MB", 1: "1GB"}
model_name = "bigscience/bloom-3b"
model_4bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
model_name, torch_dtype="auto", device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
)
```
@@ -432,7 +432,7 @@ To load a model in 8-bit for inference, use the `load_in_8bit` parameter. The `d
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", quantization_config=BitsAndBytesConfig(load_in_8bit=True))
```
If you're loading a model in 8-bit for text generation, you should use the [`~transformers.GenerationMixin.generate`] method instead of the [`Pipeline`] function which is not optimized for 8-bit models and will be slower. Some sampling strategies, like nucleus sampling, are also not supported by the [`Pipeline`] for 8-bit models. You should also place all inputs on the same device as the model:
@@ -442,7 +442,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
model_name = "bigscience/bloom-2b5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", quantization_config=BitsAndBytesConfig(load_in_8bit=True))
prompt = "Hello, my llama is cute"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
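The hunk ends before the generation step; a sketch of how the advice above plays out, assuming the snippet's `model_8bit`, `tokenizer`, and `inputs` (the `max_new_tokens` value is illustrative):

```py
# call generate directly rather than going through Pipeline
outputs = model_8bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```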
@@ -456,7 +456,7 @@ To load a model in 4-bit for inference with multiple GPUs, you can control how m
max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
model_name, torch_dtype="auto", device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```
@@ -515,7 +515,7 @@ quantization_config = BitsAndBytesConfig(
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=quantization_config)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype="auto", quantization_config=quantization_config)
# enable BetterTransformer
model = model.to_bettertransformer()


@@ -59,10 +59,10 @@ Let's try the [Whisper large-v2](https://huggingface.co/openai/whisper-large-v2)
benchmarks. It also has the added benefit of predicting punctuation and casing, neither of which are possible with
Wav2Vec2.
Let's give it a try here to see how it performs:
Let's give it a try here to see how it performs. Set `torch_dtype="auto"` to automatically load the most memory-efficient data type the weights are stored in.
```py
>>> transcriber = pipeline(model="openai/whisper-large-v2")
>>> transcriber = pipeline(model="openai/whisper-large-v2", torch_dtype="auto")
>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
```


@@ -64,7 +64,7 @@ model_8bit = AutoModelForCausalLM.from_pretrained(
)
```
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want. Setting `torch_dtype="auto"` loads the model in the data type defined in a model's `config.json` file.
```py
import torch
@@ -75,7 +75,7 @@ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = AutoModelForCausalLM.from_pretrained(
"facebook/opt-350m",
quantization_config=quantization_config,
torch_dtype=torch.float32
torch_dtype="auto"
)
model_8bit.model.decoder.layers[-1].final_layer_norm.weight.dtype
```
@@ -112,7 +112,7 @@ model_4bit = AutoModelForCausalLM.from_pretrained(
)
```
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want. Setting `torch_dtype="auto"` loads the model in the data type defined in a model's `config.json` file.
```py
import torch
@@ -123,7 +123,7 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = AutoModelForCausalLM.from_pretrained(
"facebook/opt-350m",
quantization_config=quantization_config,
torch_dtype=torch.float32
torch_dtype="auto"
)
model_4bit.model.decoder.layers[-1].final_layer_norm.weight.dtype
```
@@ -190,6 +190,7 @@ Now load your model with the custom `device_map` and `quantization_config`:
```py
model_8bit = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-1b7",
torch_dtype="auto",
device_map=device_map,
quantization_config=quantization_config,
)
@@ -212,6 +213,7 @@ quantization_config = BitsAndBytesConfig(
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype="auto",
device_map=device_map,
quantization_config=quantization_config,
)
@@ -232,6 +234,7 @@ quantization_config = BitsAndBytesConfig(
model_8bit = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config,
)
@@ -275,7 +278,7 @@ nf4_config = BitsAndBytesConfig(
bnb_4bit_quant_type="nf4",
)
model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", quantization_config=nf4_config)
```
For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the `bnb_4bit_compute_dtype` and `torch_dtype` values.
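One way to read that advice is to pick a single compute dtype and pass it in both places; a minimal sketch assuming the `model_id` from the snippet above (the bfloat16 choice is illustrative):

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the 4-bit matmul compute
)
# keep torch_dtype aligned with the compute dtype chosen above
model_nf4 = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, quantization_config=nf4_config
)
```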
@@ -292,7 +295,7 @@ double_quant_config = BitsAndBytesConfig(
bnb_4bit_use_double_quant=True,
)
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", torch_dtype="auto", quantization_config=double_quant_config)
```
## Dequantizing `bitsandbytes` models


@@ -33,13 +33,14 @@ pip install --upgrade accelerate fbgemm-gpu torch
If you are having issues with fbgemm-gpu and torch library, you might need to install the nightly release. You can follow the instruction [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch)
By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"


@@ -42,7 +42,9 @@ pip install optimum-quanto accelerate transformers
Now you can quantize a model by passing [`QuantoConfig`] object in the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.
The integration with transformers only supports weights quantization. For more complex use cases such as activation quantization, calibration, and quantization-aware training, you should use the [optimum-quanto](https://github.com/huggingface/optimum-quanto) library instead.
By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
@@ -50,7 +52,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = QuantoConfig(weights="int8")
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0", quantization_config=quantization_config)
```
Note that serialization is not supported yet with transformers but it is coming soon! If you want to save the model, you can use quanto library instead.


@@ -19,6 +19,7 @@ Before you begin, make sure the following libraries are installed with their lat
pip install --upgrade torch torchao
```
By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
import torch
@@ -28,7 +29,7 @@ model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"


@@ -245,13 +245,15 @@ Check out the [preprocess](./preprocessing) tutorial for more details about toke
<frameworkcontent>
<pt>
🤗 Transformers provides a simple and unified way to load pretrained instances. This means you can load an [`AutoModel`] like you would load an [`AutoTokenizer`]. The only difference is selecting the correct [`AutoModel`] for the task. For text (or sequence) classification, you should load [`AutoModelForSequenceClassification`]:
🤗 Transformers provides a simple and unified way to load pretrained instances. This means you can load an [`AutoModel`] like you would load an [`AutoTokenizer`]. The only difference is selecting the correct [`AutoModel`] for the task. For text (or sequence) classification, you should load [`AutoModelForSequenceClassification`].
By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
>>> from transformers import AutoModelForSequenceClassification
>>> model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
>>> pt_model = AutoModelForSequenceClassification.from_pretrained(model_name)
>>> pt_model = AutoModelForSequenceClassification.from_pretrained(model_name, torch_dtype="auto")
```
<Tip>
@@ -416,12 +418,12 @@ All models are a standard [`torch.nn.Module`](https://pytorch.org/docs/stable/nn
Depending on your task, you'll typically pass the following parameters to [`Trainer`]:
1. You'll start with a [`PreTrainedModel`] or a [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module):
1. You'll start with a [`PreTrainedModel`] or a [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module). Set `torch_dtype="auto"` to automatically load the most memory-efficient data type the weights are stored in.
```py
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased")
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype="auto")
```
2. [`TrainingArguments`] contains the model hyperparameters you can change like learning rate, batch size, and the number of epochs to train for. The default values are used if you don't specify any training arguments:
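The hunk stops before the accompanying snippet; a minimal sketch of such a starting point (the `output_dir` value is a placeholder):

```py
>>> from transformers import TrainingArguments

>>> training_args = TrainingArguments(output_dir="test_trainer")
```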


@@ -81,12 +81,14 @@ just use the button at the top-right of that framework's block!
🤗 Transformers provides a [`Trainer`] class optimized for training 🤗 Transformers models, making it easier to start training without manually writing your own training loop. The [`Trainer`] API supports a wide range of training options and features such as logging, gradient accumulation, and mixed precision.
Start by loading your model and specify the number of expected labels. From the Yelp Review [dataset card](https://huggingface.co/datasets/yelp_review_full#data-fields), you know there are five labels:
Start by loading your model and specify the number of expected labels. From the Yelp Review [dataset card](https://huggingface.co/datasets/yelp_review_full#data-fields), you know there are five labels.
By default, the weights are loaded in full precision (`torch.float32`) regardless of the data type the weights are actually stored in, such as `torch.float16`. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file and automatically use the most memory-optimal data type.
```py
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5, torch_dtype="auto")
```
<Tip>