mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix last instances of kbit -> quantized (#23797)
This commit is contained in:
parent
38dbbc2640
commit
9fea71b465
@ -2237,7 +2237,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
|
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
|
||||||
"requirements of `bitsandbytes` to enable model loading in mixed kbit. "
|
"requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
|
||||||
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
|
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
|
||||||
" torch_dtype=torch.float16 to remove this warning."
|
" torch_dtype=torch.float16 to remove this warning."
|
||||||
)
|
)
|
||||||
@ -2683,7 +2683,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
|
|
||||||
# training in 8-bit is only available in 0.37.0+
|
# training in 8-bit is only available in 0.37.0+
|
||||||
model._is_kbit_training_enabled = version.parse(
|
model._is_quantized_training_enabled = version.parse(
|
||||||
importlib_metadata.version("bitsandbytes")
|
importlib_metadata.version("bitsandbytes")
|
||||||
) >= version.parse("0.37.0")
|
) >= version.parse("0.37.0")
|
||||||
|
|
||||||
|
@ -403,8 +403,8 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# At this stage the model is already loaded
|
# At this stage the model is already loaded
|
||||||
if getattr(model, "is_loaded_in_kbit", False):
|
if getattr(model, "is_quantized", False):
|
||||||
if getattr(model, "_is_kbit_training_enabled", False):
|
if getattr(model, "_is_quantized_training_enabled", False):
|
||||||
logger.info(
|
logger.info(
|
||||||
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
||||||
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
||||||
|
Loading…
Reference in New Issue
Block a user