mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[bnb
] Fine-tuning HF 8-bit models (#21290)
* force `memory_efficient_backward=True` * enhancements - trainer support - add new flag * some changes - internal changes in `Trainer` - small refactor * make quality * Fixes - add new testing util - add new test - change test in Trainer * fix CI test * educate users on how to ft 8bit models * more checks * fix `logger` error * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * adapt from review * fix * add comment * use return instead --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
67a3920d85
commit
8298e4ec02
@ -73,6 +73,7 @@ from .utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .utils.import_utils import importlib_metadata
|
||||
from .utils.versions import require_version_core
|
||||
|
||||
|
||||
@ -2439,6 +2440,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
|
||||
)
|
||||
|
||||
# training in 8-bit is only available in 0.37.0+
|
||||
model._is_int8_training_enabled = version.parse(
|
||||
importlib_metadata.version("bitsandbytes")
|
||||
) >= version.parse("0.37.0")
|
||||
|
||||
if isinstance(device_map, str):
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
|
||||
|
@ -368,10 +368,18 @@ class Trainer:
|
||||
|
||||
# At this stage the model is already loaded
|
||||
if getattr(model, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"The model you want to train is loaded in 8-bit precision. "
|
||||
"Training an 8-bit model is not supported yet. "
|
||||
)
|
||||
if getattr(model, "_is_int8_training_enabled", False):
|
||||
logger.info(
|
||||
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
||||
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
||||
" check "
|
||||
" the examples in https://github.com/huggingface/peft for more details."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
|
||||
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
||||
)
|
||||
|
||||
# Setup Sharded DDP training
|
||||
self.sharded_ddp = None
|
||||
@ -458,7 +466,7 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if self.place_model_on_device:
|
||||
if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
|
||||
self._move_model_to_device(model, args.device)
|
||||
|
||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||
|
@ -16,6 +16,8 @@ import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
@ -33,10 +35,30 @@ from transformers.testing_utils import (
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils.versions import importlib_metadata
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class LoRALayer(nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""
|
||||
|
||||
def __init__(self, module: nn.Module, rank: int):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.adapter = nn.Sequential(
|
||||
nn.Linear(module.in_features, rank, bias=False),
|
||||
nn.Linear(rank, module.out_features, bias=False),
|
||||
)
|
||||
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||
nn.init.zeros_(self.adapter[1].weight)
|
||||
self.adapter.to(module.weight.device)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||
|
||||
|
||||
@require_bitsandbytes
|
||||
@ -335,3 +357,44 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
|
||||
# Second real batch
|
||||
output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
def setUp(self):
|
||||
self.model_name = "facebook/opt-350m"
|
||||
super().setUp()
|
||||
|
||||
def test_training(self):
|
||||
if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
|
||||
return
|
||||
|
||||
# Step 1: freeze all parameters
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False # freeze the model - train adapters later
|
||||
if param.ndim == 1:
|
||||
# cast the small parameters (e.g. layernorm) to fp32 for stability
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
# Step 2: add adapters
|
||||
for _, module in model.named_modules():
|
||||
if "OPTAttention" in repr(type(module)):
|
||||
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||
|
||||
# Step 3: dummy batch
|
||||
batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)
|
||||
|
||||
# Step 4: Check if the gradient is not None
|
||||
with torch.cuda.amp.autocast():
|
||||
out = model.forward(**batch)
|
||||
out.logits.norm().backward()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, LoRALayer):
|
||||
self.assertTrue(module.adapter[1].weight.grad is not None)
|
||||
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
self.assertTrue(module.weight.grad is None)
|
||||
|
Loading…
Reference in New Issue
Block a user