Name change AOPermod -> ModuleFqn (#38456)
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
commit 279000bb70
parent e8b292e35f
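In short, the torchao per-module configuration class is now referred to by its new name everywhere. A minimal before/after sketch of what user code looks like after this commit (the `Int8WeightOnlyConfig` default and the `lm_head` entry are illustrative, not taken from the diff):

```py
from transformers import TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig, ModuleFqnToConfig  # was: AOPerModuleConfig

# Keys are module fully-qualified names (FQNs); "_default" covers every other module,
# and mapping an FQN to None skips quantization for that module.
quant_config = ModuleFqnToConfig({"_default": Int8WeightOnlyConfig(), "lm_head": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
```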
@@ -62,7 +62,7 @@ Install torchao from PyPi or the PyTorch index with the following commands.
 # Stable release from Pypi which will default to CUDA 12.6
 pip install --upgrade torchao transformers
 ```
 </hfoption>
 <hfoption id="PyTorch Index">
 Stable Release from the PyTorch index
 ```bash
@@ -276,18 +276,18 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 
 ### Per Module Quantization
 #### 1. Skip quantization for certain layers
-With `AOPerModuleConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
+With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
 ```py
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "meta-llama/Llama-3.1-8B-Instruct"
 
-from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig
+from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
 config = Int4WeightOnlyConfig(group_size=128)
 
 # set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
-quant_config = AOPerModuleConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
+quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 # lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
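After the example in this hunk runs, the skip can be spot-checked by comparing weight types; a minimal sketch (the Llama-style attribute path mirrors the FQN used in the config, and the exact quantized tensor class depends on the torchao version):

```py
# q_proj was mapped to None, so its weight should remain an ordinary floating-point tensor.
print(type(quantized_model.model.layers[0].self_attn.q_proj.weight))
# k_proj falls under "_default" (int4), so its weight should be a torchao quantized tensor subclass.
print(type(quantized_model.model.layers[0].self_attn.k_proj.weight))
```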
@@ -311,7 +311,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "facebook/opt-125m"
 
-from torchao.quantization import Int4WeightOnlyConfig, AOPerModuleConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
+from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
 
 weight_dtype = torch.int8
 granularity = PerAxis(0)
||||||
@ -322,7 +322,7 @@ embedding_config = IntxWeightOnlyConfig(
|
|||||||
mapping_type=mapping_type,
|
mapping_type=mapping_type,
|
||||||
)
|
)
|
||||||
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
|
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
|
||||||
quant_config = AOPerModuleConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
|
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
|
||||||
# set `include_embedding` to True in order to include embedding in quantization
|
# set `include_embedding` to True in order to include embedding in quantization
|
||||||
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
|
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
|
||||||
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
|
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
|
||||||
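Assuming the OPT model above is then loaded with this `quantization_config` (as in the earlier `from_pretrained` example), the per-module embedding handling can be inspected the same way; a hedged sketch:

```py
# embed_tokens is mapped to embedding_config, so its weight should be a quantized tensor subclass,
# while embed_positions is mapped to None and should keep a plain floating-point weight.
print(type(quantized_model.model.decoder.embed_tokens.weight))
print(type(quantized_model.model.decoder.embed_positions.weight))
```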
@@ -427,8 +427,8 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)
 
 # reload the quantized model
 reloaded_model = AutoModelForCausalLM.from_pretrained(
     output_dir,
     device_map="auto",
     torch_dtype=torch.bfloat16
 )
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
@@ -463,8 +463,8 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False)
 
 # reload the quantized model
 reloaded_model = AutoModelForCausalLM.from_pretrained(
     output_dir,
     device_map="cpu",
     torch_dtype=torch.bfloat16
 )
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
@@ -261,12 +261,12 @@ class TorchAoHfQuantizer(HfQuantizer):
                 model.tie_weights()
                 setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
 
-            # handle AOPerModuleConfig, introduced in torchao 0.11.0+
-            if self.quantization_config._get_ao_version() > version.Version("0.10.0"):
-                from torchao.quantization import AOPerModuleConfig
+            # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
+            if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+                from torchao.quantization import ModuleFqnToConfig
 
                 config = self.quantization_config.get_apply_tensor_subclass()
-                if isinstance(config, AOPerModuleConfig):
+                if isinstance(config, ModuleFqnToConfig):
                     module_fqn, _ = param_name.rsplit(".", 1)
                     c = None
                     if module_fqn in config.module_fqn_to_config:
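Since the quantizer now gates the import on torchao >= 0.12.0, code that has to support both the old and the new class name could use a fallback import. A sketch, not part of this commit (the `PerModuleConfig` alias is made up for illustration):

```py
from packaging import version
import torchao

# ModuleFqnToConfig is the torchao >= 0.12.0 name for what older releases called AOPerModuleConfig.
if version.parse(torchao.__version__) >= version.parse("0.12.0"):
    from torchao.quantization import ModuleFqnToConfig as PerModuleConfig
else:
    from torchao.quantization import AOPerModuleConfig as PerModuleConfig
```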
@@ -43,10 +43,10 @@ if is_torchao_available():
         TensorCoreTiledLayout,
     )
     from torchao.quantization import (
-        AOPerModuleConfig,
         Int8WeightOnlyConfig,
         IntxWeightOnlyConfig,
         MappingType,
+        ModuleFqnToConfig,
         PerAxis,
     )
     from torchao.quantization.autoquant import AQMixin
@@ -226,7 +226,7 @@ class TorchAoTest(unittest.TestCase):
             granularity=granularity,
             mapping_type=mapping_type,
         )
-        config = AOPerModuleConfig(
+        config = ModuleFqnToConfig(
             {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
         )
         # need set `include_input_output_embeddings` to True
@@ -253,7 +253,7 @@ class TorchAoTest(unittest.TestCase):
     @require_torchao_version_greater_or_equal("0.11.0")
     def test_per_module_config_skip(self):
         linear_config = Int8WeightOnlyConfig()
-        config = AOPerModuleConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None})
+        config = ModuleFqnToConfig({"_default": linear_config, "model.layers.0.self_attn.q_proj": None})
         quant_config = TorchAoConfig(quant_type=config)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,