Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Run model as compressed/uncompressed mode (#34719)
* draft, run model in compressed/uncompressed mode
* draft
* run run_compressed=False
* run_compressed as attr
* set run_compressed=False using quantization_config
* remove redundant line
* make is_qat_trainable dependent on run_compressed status
* add tests
* lint
* fill in docstring
* add decompress
* comments
* decompress if model is compressed and not run_compressed
* apply_quant_config logic fix -- populate state dict properly
* comments
* remove non-compressed model
* make is_compressed a property
* cosmetic
* run apply_quant_config for non-compressed models -- populate scales and zero points
* add pathway for decompressing sparse models
* fix typo in is_quantization_compressed
* lint
* fix typo
Parent: 31f9a289a6
Commit: e4e404fdd0
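This change adds a `run_compressed` flag to `CompressedTensorsConfig`: by default a compressed-tensors checkpoint keeps executing in its compressed form, while `run_compressed=False` decompresses the weights at load time so the model can be finetuned (e.g. QAT). A minimal usage sketch; the model stub is taken from the new tests in this commit, and availability of the `compressed-tensors` package is assumed:

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

# Default (run_compressed=True): execute the checkpoint in compressed form,
# with quantized linear layers kept as CompressedLinear modules.
compressed = AutoModelForCausalLM.from_pretrained("nm-testing/tinyllama-w4a16-compressed-hf-quantizer")

# run_compressed=False: weights are decompressed after loading, so the model
# falls back to regular Linear modules and can be used for QAT / finetuning.
decompressed = AutoModelForCausalLM.from_pretrained(
    "nm-testing/tinyllama-w4a16-compressed-hf-quantizer",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
```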
@@ -3597,7 +3597,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             else:
                 config.quantization_config = quantization_config
-            hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+
+            hf_quantizer = AutoHfQuantizer.from_config(
+                config.quantization_config,
+                pre_quantized=pre_quantized,
+            )
+
         else:
             hf_quantizer = None

@@ -4281,7 +4286,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             dispatch_model(model, **device_map_kwargs)

         if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model)
+            hf_quantizer.postprocess_model(model, config=config)
             model.hf_quantizer = hf_quantizer

         if _adapter_model_path is not None:
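`postprocess_model` now receives the model config because the compressed-tensors quantizer needs `config._name_or_path` to locate the checkpoint it may have to decompress (see the quantizer hunk further down). Keyword arguments given to `postprocess_model` are expected to be forwarded to `_process_model_after_weight_loading`, which is why the quantizer subclasses below gain `**kwargs`. A minimal sketch of that call chain, not the actual `HfQuantizer` base-class code:

```python
class QuantizerSketch:
    """Illustrative only: how a config kwarg travels to the per-quantizer hook."""

    def postprocess_model(self, model, **kwargs):
        # from_pretrained now calls postprocess_model(model, config=config) ...
        return self._process_model_after_weight_loading(model, **kwargs)

    def _process_model_after_weight_loading(self, model, **kwargs):
        # ... so a subclass that cares can read it; others simply ignore **kwargs.
        config = kwargs.get("config", None)
        return model
```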
@@ -173,13 +173,14 @@ class AutoHfQuantizer:
             quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

         if (
-            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
+            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
             and quantization_config_from_args is not None
         ):
             # special case for GPTQ / AWQ / FbgemmFp8 config collision
             loading_attr_dict = quantization_config_from_args.get_loading_attributes()
             for attr, val in loading_attr_dict.items():
                 setattr(quantization_config, attr, val)

             warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."

         if warning_msg != "":
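With `CompressedTensorsConfig` added to the collision list, passing a config to `from_pretrained` for an already-quantized checkpoint keeps the checkpoint's quantization config and only copies over the "loading attributes" of the user-supplied one. A hypothetical illustration of that merge (it assumes the `compressed-tensors` package is installed; `get_loading_attributes` is introduced later in this diff):

```python
from transformers.utils.quantization_config import CompressedTensorsConfig

checkpoint_config = CompressedTensorsConfig()                # stands in for the config parsed from config.json
user_config = CompressedTensorsConfig(run_compressed=False)  # what the user passes to from_pretrained

# Same loop as in the hunk above: only loading attributes are transferred.
for attr, val in user_config.get_loading_attributes().items():
    setattr(checkpoint_config, attr, val)

assert checkpoint_config.run_compressed is False
```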
@@ -111,7 +111,7 @@ class AwqQuantizer(HfQuantizer):
                 " Please double check your model architecture, or submit an issue on github if you think this is a bug."
             )

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         if self.quantization_config.do_fuse:
             from ..integrations import fuse_awq_modules

@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.


+import os
+
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
-from ..utils.quantization_config import QuantizationConfigMixin
+from ..utils.quantization_config import CompressedTensorsConfig
 from .base import HfQuantizer

@@ -32,12 +35,13 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     requires_calibration = True
     required_packages = ["compressed_tensors"]

-    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+    def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
         super().__init__(quantization_config, **kwargs)

         from compressed_tensors.compressors import ModelCompressor

         self.compressor = ModelCompressor.from_compression_config(quantization_config)
+        self.run_compressed = quantization_config.run_compressed
+        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_compressed_tensors_available():
@@ -63,20 +67,57 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
         from compressed_tensors.quantization import apply_quantization_config

         ct_quantization_config = self.compressor.quantization_config
-        apply_quantization_config(model, ct_quantization_config, run_compressed=True)

-    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
-        pass
+        if self.run_compressed and self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config, run_compressed=True)
+        elif not self.is_quantization_compressed:
+            apply_quantization_config(model, ct_quantization_config)
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        """Decompress loaded model if necessary - need for qat"""
+
+        if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
+            config = kwargs.get("config", None)
+            cache_path = config._name_or_path
+
+            if not os.path.exists(cache_path):
+                from transformers.utils import cached_file
+
+                config_file_path = cached_file(cache_path, "config.json")
+                cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
+
+            if self.is_quantization_compressed and not self.run_compressed:
+                from compressed_tensors.quantization import QuantizationStatus
+
+                self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
+            self.compressor.decompress(model_path=cache_path, model=model)

     @property
-    def is_trainable(self) -> bool:
-        """Models quantized using compressed tensors can be finetuned"""
+    def is_quantization_compressed(self):
+        from compressed_tensors.quantization import QuantizationStatus
+
+        return (
+            self.quantization_config.quantization_config is not None
+            and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
+        )
+
+    @property
+    def is_sparsification_compressed(self):
+        from compressed_tensors.config.base import CompressionFormat
+
+        return (
+            self.quantization_config.sparsity_config is not None
+            and self.quantization_config.sparsity_config.format != CompressionFormat.dense.value
+        )
+
+    @property
+    def is_trainable(self):
         return True

     @property
     def is_qat_trainable(self) -> bool:
         """Loaded Models can carry out quantization aware training"""
-        return True
+        # models need to be decompressed to carry out qat
+        return not self.run_compressed or not self.is_quantization_compressed

     def is_serializable(self, safe_serialization=None) -> bool:
         """Models quantized using compressed tensors can be saved to disk"""
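Taken together, the quantizer now branches on two questions: was the checkpoint saved in compressed form (`is_quantization_compressed` / `is_sparsification_compressed`), and does the user want compressed execution (`run_compressed`)? The sketch below summarizes those branches in plain Python; it is an illustration of the control flow above, not library code:

```python
def loading_plan(run_compressed: bool, quant_compressed: bool, sparse_compressed: bool) -> list:
    """Hypothetical helper mirroring the branches of the quantizer above."""
    steps = []
    # _process_model_before_weight_loading
    if run_compressed and quant_compressed:
        steps.append("apply quantization config with run_compressed=True (CompressedLinear modules)")
    elif not quant_compressed:
        steps.append("apply quantization config normally (populate scales and zero points)")
    # _process_model_after_weight_loading
    if (quant_compressed and not run_compressed) or sparse_compressed:
        steps.append("decompress the checkpoint into the loaded model")
    # is_qat_trainable
    steps.append(f"QAT trainable: {not run_compressed or not quant_compressed}")
    return steps


for step in loading_plan(run_compressed=False, quant_compressed=True, sparse_compressed=False):
    print(step)
```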
@@ -197,7 +197,7 @@ class QuantoHfQuantizer(HfQuantizer):
             )
         model.config.quantization_config = self.quantization_config

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         return model

     @property
@@ -195,7 +195,7 @@ class TorchAoHfQuantizer(HfQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())

-    def _process_model_after_weight_loading(self, model):
+    def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         return

@@ -1077,7 +1077,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
             dictionary mapping group name to a quantization scheme definition
         format (`str`, *optional*, defaults to `"dense"`):
-            format the model is represented as
+            format the model is represented as. Set `run_compressed` True to execute model as the
+            compressed format if not `dense`
         quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
             status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
         kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
@@ -1090,6 +1091,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
             configuration for sparsity compression
         quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
             do not override, should be compressed-tensors
+        run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
+            emulate compressed model execution if True, otherwise use default submodule
     """

     def __init__(
@@ -1102,14 +1105,17 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         ignore: Optional[List[str]] = None,
         sparsity_config: Dict[str, Any] = None,
         quant_method: str = "compressed-tensors",
+        run_compressed: bool = True,
         **kwargs,
     ):
-        from compressed_tensors import QuantizationConfig
         from compressed_tensors.config import SparsityCompressionConfig
+        from compressed_tensors.quantization import QuantizationConfig

         self.quantization_config = None
         self.sparsity_config = None

+        self.run_compressed = run_compressed
+
         # parse from dict to load nested QuantizationScheme objects
         if config_groups or kv_cache_scheme:
             self.quantization_config = QuantizationConfig.parse_obj(
@@ -1121,6 +1127,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
                     "kv_cache_scheme": kv_cache_scheme,
                     "global_compression_ratio": global_compression_ratio,
                     "ignore": ignore,
+                    "run_compressed": run_compressed,
                     **kwargs,
                 }
             )
@@ -1149,6 +1156,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):

         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
+
         """

         if "quantization_config" in config_dict:
@@ -1200,6 +1208,9 @@ class CompressedTensorsConfig(QuantizationConfigMixin):

         return serializable_config_dict

+    def get_loading_attributes(self):
+        return {"run_compressed": self.run_compressed}
+

 @dataclass
 class FbgemmFp8Config(QuantizationConfigMixin):
@@ -0,0 +1,80 @@
+import gc
+import unittest
+
+from transformers import AutoModelForCausalLM
+from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+@require_compressed_tensors
+@require_torch
+class CompressedTensorsTest(unittest.TestCase):
+    model_sparse_uncompressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_uncompressed"
+    model_sparse_compressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_compressed"
+
+    prompt = "Paris is the capital of which country?"
+
+    stubs = [model_sparse_uncompressed, model_sparse_compressed]
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_compressed_uncompressed_model_shapes(self):
+        """
+        Check that the weights are the same between
+        uncompressed and compressed-decompressed model
+        Sparse compressed modules' weights are "packed" and shape/value will
+        differ
+        """
+
+        def _has_nested_attr(obj, attr_path):
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                if not hasattr(obj, attr):
+                    return None
+                obj = getattr(obj, attr)
+            return obj
+
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        uncompressed_model = AutoModelForCausalLM.from_pretrained(
+            self.model_sparse_uncompressed,
+        )
+
+        compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
+            self.model_sparse_compressed,
+        )
+
+        for name, submodule in iter_named_leaf_modules(
+            uncompressed_model,
+        ):
+            if comp_decomp_obj := _has_nested_attr(compressed_model_decompressed, name):
+                if hasattr(submodule, "weight"):
+                    assert torch.equal(submodule.weight, comp_decomp_obj.weight)
+
+    def test_run_compressed_outputs_match(self):
+        """Check that uncompressed and compressed-decompressed model outputs are the same"""
+
+        from transformers import AutoTokenizer
+
+        for stub in self.stubs:
+            tokenizer = AutoTokenizer.from_pretrained(stub)
+            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
+
+            uncompressed_model = AutoModelForCausalLM.from_pretrained(
+                self.model_sparse_uncompressed,
+            )
+            output_rc_true = uncompressed_model.generate(input_ids, max_new_tokens=100)
+
+            compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
+                self.model_sparse_compressed,
+            )
+            output_rc_false = compressed_model_decompressed.generate(input_ids, max_new_tokens=100)
+
+            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])
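The `_has_nested_attr` helper above walks a dotted module path so a module from the uncompressed model can be looked up by name on the decompressed one. A quick standalone illustration with a toy module (not part of the test):

```python
import torch.nn as nn


def _has_nested_attr(obj, attr_path):
    # Same idea as in the test: return None as soon as one path segment is missing.
    for attr in attr_path.split("."):
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj


toy = nn.Sequential(nn.Linear(4, 4))            # toy stand-in for a real model
print(_has_nested_attr(toy, "0.weight").shape)  # torch.Size([4, 4])
print(_has_nested_attr(toy, "0.missing"))       # None
```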
@@ -0,0 +1,94 @@
+import gc
+import unittest
+
+from transformers import AutoModelForCausalLM
+from transformers.testing_utils import require_compressed_tensors, require_torch
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+@require_compressed_tensors
+@require_torch
+class CompressedTensorsTest(unittest.TestCase):
+    tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
+    tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
+
+    prompt = "Paris is the capital of which country?"
+
+    stubs = [tinyllama_w4a16, tinyllama_w8a8]
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_default_run_compressed__True(self):
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        for stub in self.stubs:
+            model = AutoModelForCausalLM.from_pretrained(
+                stub,
+            )
+            compressed_linear_counts = 0
+
+            for _, submodule in iter_named_leaf_modules(
+                model,
+            ):
+                if isinstance(submodule, CompressedLinear):
+                    compressed_linear_counts += 1
+
+            # some linear models are not compressed - ex. lm_head
+            assert compressed_linear_counts > 0
+
+    def test_default_run_compressed__False(self):
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+        from compressed_tensors.quantization.utils import iter_named_leaf_modules
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+        for stub in self.stubs:
+            model = AutoModelForCausalLM.from_pretrained(
+                stub,
+                quantization_config=quantization_config,
+            )
+            compressed_linear_counts = 0
+
+            for _, submodule in iter_named_leaf_modules(
+                model,
+            ):
+                if isinstance(submodule, CompressedLinear):
+                    compressed_linear_counts += 1
+
+            # No modules should be CompressedLinear
+            assert compressed_linear_counts == 0
+
+    def test_run_compressed_outputs_match(self):
+        """Check that run_compressed=True/False output are the same"""
+
+        from transformers import AutoTokenizer
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+        for stub in self.stubs:
+            tokenizer = AutoTokenizer.from_pretrained(stub)
+            input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
+
+            model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
+                stub,
+            )
+            output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)
+
+            model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
+                stub,
+                quantization_config=quantization_config,
+            )
+            output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)
+
+            assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])