[Trainer] Refactor trainer + bnb logic (#26248)

* refactor trainer + bnb logic

* remove logger.info

* oops
This commit is contained in:
Younes Belkada 2023-09-20 17:38:59 +02:00 committed by GitHub
parent f94c9b3d86
commit 0b5024ce72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -402,19 +402,23 @@ class Trainer:
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)
_is_peft_model = is_peft_available() and isinstance(model, PeftModel)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)
# At this stage the model is already loaded
if getattr(model, "is_quantized", False) and not getattr(model, "_hf_peft_config_loaded", False):
if getattr(model, "_is_quantized_training_enabled", False):
logger.info(
"The model is quantized. To train this model you need to add additional modules"
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
" check the examples in https://github.com/huggingface/peft for more details."
)
else:
raise ValueError(
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)
if _is_quantized_and_base_model and not _is_peft_model:
raise ValueError(
"You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
" for more details"
)
elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
raise ValueError(
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)
# Setup Sharded DDP training
self.sharded_ddp = None