Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 05:10:06 +06:00
Merge d2a0df9e9d into 37a239ca50
This commit is contained in:
commit e4582ff16e
@@ -363,3 +363,9 @@ class Bnb4BitHfQuantizer(HfQuantizer):
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         return model
+
+    @property
+    def is_compileable(self) -> bool:
+        # Compatible with PyTorch 2.4+ for fullgraph=False.
+        # Requires PyTorch 2.8 nightly for fullgraph=True.
+        return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
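The added is_compileable property gates torch.compile support on the installed bitsandbytes version. A minimal standalone sketch of that check, assuming bitsandbytes and packaging are importable; the bnb_supports_compile helper name is illustrative, not part of the diff:

import importlib.metadata

from packaging import version


def bnb_supports_compile() -> bool:
    # Mirror the version gate above: compile support is assumed from bitsandbytes 0.46.0 on.
    try:
        bnb_version = importlib.metadata.version("bitsandbytes")
    except importlib.metadata.PackageNotFoundError:
        # bitsandbytes is not installed, so there is nothing to compile against.
        return False
    return version.parse(bnb_version) >= version.parse("0.46.0")


print("bitsandbytes compile support:", bnb_supports_compile())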
@@ -314,3 +314,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             model, self.modules_to_not_convert, quantization_config=self.quantization_config
         )
         return model
+
+    @property
+    def is_compileable(self) -> bool:
+        return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.46.0")
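The 8-bit quantizer gains the same property, without the PyTorch version comments. A sketch of one way such a flag could be consumed, assuming a caller that decides whether to wrap a model's forward pass in torch.compile; this is illustrative, not the transformers internals:

import torch


def maybe_compile_forward(model: torch.nn.Module, is_compileable: bool) -> torch.nn.Module:
    # Only compile when the quantizer reports support; fullgraph=False tolerates graph
    # breaks, matching the weaker requirement noted in the 4-bit hunk above.
    if is_compileable:
        model.forward = torch.compile(model.forward, fullgraph=False)
    return model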
@@ -831,11 +831,3 @@ class Bnb4bitCompile(unittest.TestCase):
             max_new_tokens=10,
             cache_implementation="static",
         )
-        with self.assertRaises(Exception):
-            # overwrite property
-            object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True)
-            self.model_4bit.generate(
-                input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
-                max_new_tokens=10,
-                cache_implementation="static",
-            )
@@ -1005,11 +1005,3 @@ class Bnb8bitCompile(unittest.TestCase):
             max_new_tokens=10,
             cache_implementation="static",
         )
-
-        with self.assertRaises(Exception):
-            object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True)
-            self.model_8bit.generate(
-                input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
-                max_new_tokens=10,
-                cache_implementation="static",
-            )
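Both test hunks drop the assertRaises(Exception) blocks that forced the compile path and expected a failure; with the new property, static-cache generation on the quantized models is expected to succeed. A condensed sketch of the remaining 4-bit check, assuming an illustrative small checkpoint in place of the real test fixtures:

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class Bnb4bitCompileSketch(unittest.TestCase):
    # Illustrative checkpoint; the real tests build their own model_4bit fixture.
    model_name = "facebook/opt-125m"

    def test_generate_static_cache(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )
        encoded_input = tokenizer("Hello my name is", return_tensors="pt")
        # With is_compileable reporting True on recent bitsandbytes, this call should
        # no longer raise when a static cache is requested.
        model.generate(
            input_ids=encoded_input["input_ids"].to(model.device),
            max_new_tokens=10,
            cache_implementation="static",
        )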