# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights

if is_quanto_available():
    from quanto import QLayerNorm, QLinear

    from transformers.integrations.quanto import replace_with_quanto_layers


class QuantoConfigTest(unittest.TestCase):
    def test_attributes(self):
        pass


@require_quanto
@require_accelerate
class QuantoTestIntegration(unittest.TestCase):
    model_id = "facebook/opt-350m"

    def setUp(self):
        config = AutoConfig.from_pretrained(self.model_id)
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_config(config)
        self.nb_linear = 0
        self.nb_layernorm = 0
        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                self.nb_linear += 1
            elif isinstance(module, torch.nn.LayerNorm):
                self.nb_layernorm += 1

    def test_weight_only_quantization_conversion(self):
        """
        Simple test that checks that the quantized model has been converted properly when using weight-only
        quantization
        """

        # Try with weight-only quantization
        quantization_config = QuantoConfig(weights="int8", activations=None)
        self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)

        nb_qlinear = 0
        for module in self.model.modules():
            if isinstance(module, QLinear):
                nb_qlinear += 1

        self.assertEqual(self.nb_linear, nb_qlinear)

    def test_weight_and_activation_quantization_conversion(self):
        """
        Simple test that checks that the quantized model has been converted properly when using
        weight + activation quantization
        """

        # Try with weight + activation quantization
        quantization_config = QuantoConfig(weights="int8", activations="int8")
        self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)

        nb_qlinear = 0
        nb_qlayernorm = 0
        for module in self.model.modules():
            if isinstance(module, QLinear):
                nb_qlinear += 1
            if isinstance(module, QLayerNorm):
                nb_qlayernorm += 1

        self.assertEqual(self.nb_linear, nb_qlinear)
        self.assertEqual(self.nb_layernorm, nb_qlayernorm)

    def test_conversion_with_modules_to_not_convert(self):
        """
        Simple test that checks that the quantized model has been converted properly when specifying the
        modules_to_not_convert argument
        """

        # Try with weight + activation quantization
        quantization_config = QuantoConfig(weights="int8", activations="int8")
        self.model, _ = replace_with_quanto_layers(
            self.model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
        )

        nb_qlinear = 0
        nb_qlayernorm = 0
        for module in self.model.modules():
            if isinstance(module, QLinear):
                nb_qlinear += 1
            if isinstance(module, QLayerNorm):
                nb_qlayernorm += 1

        # The lm_head linear layer was excluded from the conversion
        self.assertEqual(self.nb_linear - 1, nb_qlinear)


@slow
@require_torch_gpu
@require_quanto
@require_accelerate
class QuantoQuantizationTest(unittest.TestCase):
    """
    Test 8-bit weight-only quantization
    """

    model_name = "bigscience/bloom-560m"
    weights = "int8"
    activations = None
    device_map = "cpu"

    input_text = "Hello my name is"
    EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer and I"

    def setUp(self):
        """
        Set up the quantized model
        """
        quantization_config = QuantoConfig(
            weights=self.weights,
            activations=self.activations,
        )
        self.quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            quantization_config=quantization_config,
            torch_dtype=torch.float32,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.have_accelerate_hooks = (
            getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
        )

    def check_inference_correctness(self, model, device):
        r"""
        Test the generation quality of the quantized model by checking that we match the expected output.
        Given that we are operating on small numbers and the test model is relatively small, we might not get
        the same output across GPUs, so we only generate a few tokens (5-10) and check their output.
        """
        if not self.have_accelerate_hooks:
            model.to(device)
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(device), max_new_tokens=10)
        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_cpu(self):
        """
        Simple test to check the quality of the model on cpu by comparing the generated tokens with the expected tokens
        """
        self.check_inference_correctness(self.quantized_model, "cpu")

    def test_generate_quality_cuda(self):
        """
        Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
        """
        self.check_inference_correctness(self.quantized_model, "cuda")

    def test_quantized_model_layers(self):
        """
        Suite of simple tests to check that the layers are quantized and work properly
        """
        from quanto import QBitsTensor, QModuleMixin, QTensor

        # Test the type of the quantized layer
        self.assertTrue(
            isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value, QModuleMixin)
        )
        self.assertTrue(
            isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QTensor)
        )
        if self.weights == "int4":
            self.assertTrue(
                isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QBitsTensor)
            )

        # Check that the lm_head was indeed not quantized, just like bnb
        self.assertTrue(
            isinstance(self.quantized_model.lm_head, torch.nn.Linear)
            and not isinstance(self.quantized_model.lm_head, QModuleMixin)
        )
        if self.device_map in ["cpu", "cuda"]:
            self.assertEqual(
                self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type,
                self.device_map,
            )
            self.quantized_model.to(0)
        self.assertEqual(
            self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
        )

    def test_serialization_bin(self):
        """
        Test the serialization, the loading and the inference of the quantized weights
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError) as e:
                self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
            self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
            # TODO: replace by the following when it works
            # quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
            #     tmpdirname, torch_dtype=torch.float32, device_map="cpu"
            # )
            # self.check_inference_correctness(quantized_model_from_saved, device="cuda")

    def test_serialization_safetensors(self):
        """
        Test the serialization, the loading and the inference of the quantized weights
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError) as e:
                self.quantized_model.save_pretrained(tmpdirname)
            self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
            # TODO: replace by the following when it works
            # quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
            #     tmpdirname, torch_dtype=torch.float32, device_map="cpu"
            # )
            # self.check_inference_correctness(quantized_model_from_saved, device="cuda")

    def check_same_model(self, model1, model2):
        # Check that both models have identical parameters (names, shapes, devices, dtypes and values)
        d0 = dict(model1.named_parameters())
        d1 = dict(model2.named_parameters())
        self.assertTrue(d0.keys() == d1.keys())
        for k in d0.keys():
            self.assertTrue(d0[k].shape == d1[k].shape)
            self.assertTrue(d0[k].device.type == d1[k].device.type)
            self.assertTrue(d0[k].device == d1[k].device)
            self.assertTrue(d0[k].dtype == d1[k].dtype)
            self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))

    def test_compare_with_quanto(self):
        from quanto import freeze, qint4, qint8, quantize

        w_mapping = {"int8": qint8, "int4": qint4}
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.float32,
        )
        # we do not quantize the lm_head since we don't do that in transformers
        quantize(model.transformer, weights=w_mapping[self.weights])
        freeze(model.transformer)
        self.check_same_model(model, self.quantized_model)
        self.check_inference_correctness(model, device="cuda")

    @unittest.skip
    def test_load_from_quanto_saved(self):
        from quanto import freeze, qint4, qint8, quantize

        from transformers import QuantoConfig

        w_mapping = {"int8": qint8, "int4": qint4}
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.float32,
        )
        # we do not quantize the lm_head since we don't do that in transformers
        quantize(model.transformer, weights=w_mapping[self.weights])
        freeze(model.transformer)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.config.quantization_config = QuantoConfig(
                weights=self.weights, activations=self.activations, modules_to_not_convert=["lm_head"]
            )
            model.save_pretrained(tmpdirname, safe_serialization=False)
            quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
                tmpdirname,
                device_map=self.device_map,
                torch_dtype=torch.float32,
            )
        self.check_same_model(model, quantized_model_from_saved)
        self.check_inference_correctness(quantized_model_from_saved, device="cuda")


class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
    device_map = {
        "transformer.word_embeddings": 0,
        "transformer.word_embeddings_layernorm": 0,
        "transformer.ln_f": 0,
        "transformer.h.0": 0,
        "transformer.h.1": 0,
        "transformer.h.2": 0,
        "transformer.h.3": 0,
        "transformer.h.4": 0,
        "transformer.h.5": 0,
        "transformer.h.6": 0,
        "transformer.h.7": 0,
        "transformer.h.8": 0,
        "transformer.h.9": 0,
        "transformer.h.10": 0,
        "transformer.h.11": 0,
        "transformer.h.12": 0,
        "transformer.h.13": 0,
        "transformer.h.14": 0,
        "transformer.h.15": 0,
        "transformer.h.16": 0,
        "transformer.h.17": 0,
        "transformer.h.18": 0,
        "transformer.h.19": 0,
        "transformer.h.20": 0,
        "transformer.h.21": 0,
        "transformer.h.22": "cpu",
        "transformer.h.23": "disk",
        "lm_head": 0,
    }

    # the execution device is a gpu, so the cpu generation test does not apply
    def test_generate_quality_cpu(self):
        pass
    # we can't save offloaded values
    def test_serialization_bin(self):
        pass

    def test_serialization_safetensors(self):
        pass

    def test_compare_with_quanto(self):
        pass

    def test_load_from_quanto_saved(self):
        pass

    def test_check_offload_quantized(self):
        """
        Check that the values offloaded to the cpu and to the disk are kept unquantized
        """
        import quanto

        cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
            "weight"
        ]
        disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
            "weight"
        ]
        self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
        self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
        if self.weights == "int4":
            self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QBitsTensor))
            self.assertTrue(
                isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
            )


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationTest(QuantoQuantizationTest):
    """
    Perform the same tests as in QuantoQuantizationTest but with a serialized model.
    """

    def setUp(self):
        """
        Set up the quantized model
        """
        quantization_config = QuantoConfig(
            weights=self.weights,
            activations=self.activations,
        )
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            quantization_config=quantization_config,
            torch_dtype=torch.float32,
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
            self.quantized_model = AutoModelForCausalLM.from_pretrained(
                tmpdirname, torch_dtype=torch.float32, device_map=self.device_map
            )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.have_accelerate_hooks = (
            getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
        )


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
    """
    Perform the same tests as in QuantoQuantizationTest but with the model on cuda
    """

    device_map = "cuda:0"


class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


@require_torch_gpu
class QuantoQuantizationActivationTest(unittest.TestCase):
    def test_quantize_activation(self):
        quantization_config = QuantoConfig(
            weights="int8",
            activations="int8",
        )
        with self.assertRaises(ValueError) as e:
            AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
        self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))