mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Changing __repr__ in torchao to show quantized Linear (#34202)
* Changing __repr__ in torchao * small update * make style * small update * add LinearActivationQuantizedTensor * remove some cases * update imports & handle return None * update
This commit is contained in:
parent
f2d5dfbab2
commit
d2bae7ee9d
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from packaging import version
|
||||
@ -30,9 +31,7 @@ from ..utils import is_torch_available, is_torchao_available, logging
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -46,6 +45,25 @@ def find_parent(model, name):
|
||||
return parent
|
||||
|
||||
|
||||
def _quantization_type(weight):
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
|
||||
if isinstance(weight, AffineQuantizedTensor):
|
||||
return f"{weight.__class__.__name__}({weight._quantization_type()})"
|
||||
|
||||
if isinstance(weight, LinearActivationQuantizedTensor):
|
||||
return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
|
||||
|
||||
|
||||
def _linear_extra_repr(self):
|
||||
weight = _quantization_type(self.weight)
|
||||
if weight is None:
|
||||
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
|
||||
else:
|
||||
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
|
||||
|
||||
|
||||
class TorchAoHfQuantizer(HfQuantizer):
|
||||
"""
|
||||
Quantizer for torchao: https://github.com/pytorch/ao/
|
||||
@ -152,9 +170,17 @@ class TorchAoHfQuantizer(HfQuantizer):
|
||||
Each nn.Linear layer that needs to be quantized is processsed here.
|
||||
First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
|
||||
"""
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
|
||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||
|
||||
if self.pre_quantized:
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
|
||||
if isinstance(module, nn.Linear):
|
||||
module.extra_repr = types.MethodType(_linear_extra_repr, module)
|
||||
else:
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
|
||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||
|
||||
def _process_model_after_weight_loading(self, model):
|
||||
"""No process required for torchao quantized model"""
|
||||
|
Loading…
Reference in New Issue
Block a user