[bnb] Fine-tuning HF 8-bit models (#21290)

* force `memory_efficient_backward=True`

* enhancements

- trainer support
- add new flag

* some changes

- internal changes in `Trainer`
- small refactor

* make quality

* Fixes

- add new testing util
- add new test
- change test in Trainer

* fix CI test

* educate users on how to fine-tune 8-bit models

* more checks

* fix `logger` error

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* adapt from review

* fix

* add comment

* use return instead

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Younes Belkada 2023-02-02 16:39:23 +01:00 committed by GitHub
parent 67a3920d85
commit 8298e4ec02
3 changed files with 82 additions and 5 deletions
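
For context, a minimal sketch of the workflow this commit unblocks (not part of the diff): load a checkpoint in 8-bit with `load_in_8bit=True`, freeze the quantized base weights, and train only adapter modules added on top. The checkpoint name and the fp32 cast of 1-D parameters mirror the new test below; `bitsandbytes>=0.37.0` is assumed to be installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-350m"  # same checkpoint as the new training test below
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")

# Freeze the int8 base weights; only adapter parameters (attached separately) should train.
for param in model.parameters():
    param.requires_grad = False
    if param.ndim == 1:
        # keep small parameters (e.g. LayerNorm) in fp32 for numerical stability
        param.data = param.data.to(torch.float32)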


@@ -73,6 +73,7 @@ from .utils import (
     logging,
     replace_return_docstrings,
 )
+from .utils.import_utils import importlib_metadata
 from .utils.versions import require_version_core
@@ -2439,6 +2440,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
             )
 
+            # training in 8-bit is only available in 0.37.0+
+            model._is_int8_training_enabled = version.parse(
+                importlib_metadata.version("bitsandbytes")
+            ) >= version.parse("0.37.0")
+
         if isinstance(device_map, str):
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
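
The `_is_int8_training_enabled` flag added above is purely a version gate on the installed `bitsandbytes` package. An equivalent standalone check, sketched here with the stdlib `importlib.metadata` instead of the internal `importlib_metadata` shim:

import importlib.metadata

from packaging import version

# True when the installed bitsandbytes can backpropagate through int8 layers (0.37.0+)
int8_training_available = version.parse(
    importlib.metadata.version("bitsandbytes")
) >= version.parse("0.37.0")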


@@ -368,10 +368,18 @@ class Trainer:
         # At this stage the model is already loaded
         if getattr(model, "is_loaded_in_8bit", False):
-            raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. "
-                "Training an 8-bit model is not supported yet. "
-            )
+            if getattr(model, "_is_int8_training_enabled", False):
+                logger.info(
+                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
+                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
+                    " check "
+                    " the examples in https://github.com/huggingface/peft for more details."
+                )
+            else:
+                raise ValueError(
+                    "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
+                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+                )
 
         # Setup Sharded DDP training
         self.sharded_ddp = None
@@ -458,7 +466,7 @@
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
 
-        if self.place_model_on_device:
+        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
             self._move_model_to_device(model, args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
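
With the two changes above, `Trainer` accepts an 8-bit model whenever `bitsandbytes>=0.37.0` is installed, and it skips moving such a model to a device (the `device_map` placed it already). A hedged sketch of handing an adapter-equipped 8-bit model to `Trainer`; `train_dataset`, the output directory, and the hyperparameters are placeholders, and `model` is assumed to be frozen with trainable adapters attached as in the test below:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="opt350m-int8-adapters",  # placeholder
    per_device_train_batch_size=4,
    learning_rate=2e-4,
)

# Trainer now logs an informational message instead of raising, and does not call
# _move_model_to_device for the 8-bit model.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # gradients flow only into the adapter parameters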


@@ -16,6 +16,8 @@ import gc
 import tempfile
 import unittest
 
+from packaging import version
+
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,
@@ -33,10 +35,30 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     slow,
 )
+from transformers.utils.versions import importlib_metadata
 
 
 if is_torch_available():
     import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
 
 
 @require_bitsandbytes
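
A quick illustration of the `LoRALayer` test helper above, on a plain CPU linear layer: the second adapter matrix is zero-initialized, so the wrapped layer initially reproduces the base layer's output; in the training test the base weights are frozen so only the adapter weights receive gradients.

import torch
import torch.nn as nn

base = nn.Linear(16, 32, bias=False)
wrapped = LoRALayer(base, rank=4)  # LoRALayer as defined in the diff above

x = torch.randn(2, 16)
# zero-initialized adapter[1] means the adapter contributes nothing at initialization
assert torch.allclose(wrapped(x), base(x))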
@@ -335,3 +357,44 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
         # Second real batch
         output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
         self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+
+class MixedInt8TestTraining(BaseMixedInt8Test):
+    def setUp(self):
+        self.model_name = "facebook/opt-350m"
+        super().setUp()
+
+    def test_training(self):
+        if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
+            return
+
+        # Step 1: freeze all parameters
+        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+
+        for param in model.parameters():
+            param.requires_grad = False  # freeze the model - train adapters later
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+
+        # Step 2: add adapters
+        for _, module in model.named_modules():
+            if "OPTAttention" in repr(type(module)):
+                module.q_proj = LoRALayer(module.q_proj, rank=16)
+                module.k_proj = LoRALayer(module.k_proj, rank=16)
+                module.v_proj = LoRALayer(module.v_proj, rank=16)
+
+        # Step 3: dummy batch
+        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)
+
+        # Step 4: Check if the gradient is not None
+        with torch.cuda.amp.autocast():
+            out = model.forward(**batch)
+            out.logits.norm().backward()
+
+        for module in model.modules():
+            if isinstance(module, LoRALayer):
+                self.assertTrue(module.adapter[1].weight.grad is not None)
+                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+            elif isinstance(module, nn.Embedding):
+                self.assertTrue(module.weight.grad is None)