Fix last instances of kbit -> quantized (#23797)

Sylvain Gugger 2023-05-31 05:38:20 -04:00 committed by GitHub
parent 38dbbc2640
commit 9fea71b465
2 changed files with 4 additions and 4 deletions

src/transformers/modeling_utils.py

@@ -2237,7 +2237,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
                 logger.info(
                     f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
-                    "requirements of `bitsandbytes` to enable model loading in mixed kbit. "
+                    "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
                     "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                     " torch_dtype=torch.float16 to remove this warning."
                 )
@@ -2683,7 +2683,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             )
             # training in 8-bit is only available in 0.37.0+
-            model._is_kbit_training_enabled = version.parse(
+            model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
             ) >= version.parse("0.37.0")
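
The renamed flag is simply the result of a version comparison; a standalone sketch of the same gate, assuming bitsandbytes is installed:

import importlib.metadata as importlib_metadata
from packaging import version

# True when the installed bitsandbytes supports training quantized models
# (8-bit training landed in bitsandbytes 0.37.0).
is_quantized_training_enabled = version.parse(
    importlib_metadata.version("bitsandbytes")
) >= version.parse("0.37.0")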

src/transformers/trainer.py

@@ -403,8 +403,8 @@ class Trainer:
             )
         # At this stage the model is already loaded
-        if getattr(model, "is_loaded_in_kbit", False):
-            if getattr(model, "_is_kbit_training_enabled", False):
+        if getattr(model, "is_quantized", False):
+            if getattr(model, "_is_quantized_training_enabled", False):
                 logger.info(
                     "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                     " inside the model such as adapters using `peft` library and freeze the model weights. Please"