Changing __repr__ in torchao to show quantized Linear (#34202)

* Changing __repr__ in torchao

* small update

* make style

* small update

* add LinearActivationQuantizedTensor

* remove some cases

* update imports & handle return None

* update
Mohamed Mekkouri 2024-11-05 16:11:02 +01:00 committed by GitHub
parent f2d5dfbab2
commit d2bae7ee9d
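For context, an editor's sketch (not part of the commit) of the intended effect: after this change, printing a reloaded torchao checkpoint names the quantized tensor subclass on each Linear instead of the stock `nn.Linear` repr. The model name, quantization settings, and save path below are illustrative assumptions.

```python
# Hedged sketch of the intended effect; model name, settings, and paths are
# illustrative assumptions, not taken from the commit.
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=quant_config, device_map="auto"
)
# torchao tensor subclasses are not safetensors-compatible, so serialize with
# safe_serialization=False
model.save_pretrained("opt-125m-int4", safe_serialization=False)

# Reloading takes the pre_quantized path in the diff below, where
# _linear_extra_repr is attached to each nn.Linear, so the repr now
# surfaces the subclass:
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-int4", device_map="auto")
print(reloaded.model.decoder.layers[0].self_attn.q_proj)
# Linear(in_features=768, out_features=768, weight=AffineQuantizedTensor(...))
```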


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+import types
 from typing import TYPE_CHECKING, Union

 from packaging import version
@@ -30,9 +31,7 @@ from ..utils import is_torch_available, is_torchao_available, logging

 if is_torch_available():
     import torch
-
-if is_torchao_available():
-    from torchao.quantization import quantize_
+    import torch.nn as nn


 logger = logging.get_logger(__name__)
@@ -46,6 +45,25 @@ def find_parent(model, name):
     return parent


+def _quantization_type(weight):
+    from torchao.dtypes import AffineQuantizedTensor
+    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+    if isinstance(weight, AffineQuantizedTensor):
+        return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+
+def _linear_extra_repr(self):
+    weight = _quantization_type(self.weight)
+    if weight is None:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+    else:
+        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
+
 class TorchAoHfQuantizer(HfQuantizer):
     """
     Quantizer for torchao: https://github.com/pytorch/ao/
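An aside on the mechanism (standalone sketch, not part of the diff): `types.MethodType` binds a function to a single module instance, so `extra_repr` is overridden only on the layers the quantizer touches, while every other `nn.Linear` keeps the default repr. This works because `nn.Module.__repr__` looks up `self.extra_repr`, and an instance attribute shadows the class method.

```python
# Standalone sketch of the types.MethodType trick used by _linear_extra_repr:
# the bound method is stored as an instance attribute, which shadows the
# class-level extra_repr for this one instance only.
import types
import torch.nn as nn

def _patched_extra_repr(self):
    return f"in_features={self.in_features}, out_features={self.out_features} (PATCHED)"

layer = nn.Linear(4, 2)
layer.extra_repr = types.MethodType(_patched_extra_repr, layer)
print(layer)            # Linear(in_features=4, out_features=2 (PATCHED))
print(nn.Linear(4, 2))  # Linear(in_features=4, out_features=2, bias=True)
```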
@@ -152,9 +170,17 @@ class TorchAoHfQuantizer(HfQuantizer):
         Each nn.Linear layer that needs to be quantized is processed here.
         First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
         """
+        from torchao.quantization import quantize_
+
         module, tensor_name = get_module_from_name(model, param_name)
-        module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-        quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+        if self.pre_quantized:
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+            if isinstance(module, nn.Linear):
+                module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        else:
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            quantize_(module, self.quantization_config.get_apply_tensor_subclass())

     def _process_model_after_weight_loading(self, model):
         """No process required for torchao quantized model"""