# Customizing model components

Another way to customize a model is to modify its components, rather than writing a new model entirely, allowing you to tailor a model to your specific use case. For example, you can add new layers or optimize the attention mechanism of an architecture. Customizations are applied directly to a Transformers model so that you can continue to use features such as [`Trainer`], [`PreTrainedModel`], and the PEFT library.

This guide will show you how to customize a model's attention mechanism in order to apply Low-Rank Adaptation (LoRA) to it.

> [!TIP]
> The `clear_import_cache` utility is very useful when you're iteratively modifying and developing model code. It removes all cached Transformers modules and allows Python to reload the modified code without constantly restarting your environment.

```py
from transformers import AutoModel
from transformers.utils.import_utils import clear_import_cache

model = AutoModel.from_pretrained("bert-base-uncased")
# modifications to model code
# clear cache to reload modified code
clear_import_cache()
# re-import to use updated code
model = AutoModel.from_pretrained("bert-base-uncased")
```

## Attention class

Segment Anything is an image segmentation model, and it combines the query-key-value (qkv) projection in its attention mechanism. To reduce the number of trainable parameters and computational overhead, you can apply LoRA to the qkv projection. This requires splitting the qkv projection so that you can separately target the `q` and `v` projections with LoRA.

1. Create a custom attention class, `SamVisionAttentionSplit`, by subclassing the original `SamVisionAttention` class. In the `__init__`, delete the combined `qkv` and create a separate linear layer for `q`, `k` and `v`.

```py
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # remove the combined qkv projection
        del self.qkv
        # create separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
```
2. The `split_q_k_v_load_hook` function splits the pretrained `qkv` weights into separate `q`, `k`, and `v` weights when loading the model to ensure compatibility with any pretrained model.

```py
    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)

        # remove the old qkv keys
        for key in keys_to_delete:
            del state_dict[key]
```
3. In the forward pass, `q`, `k`, and `v` are computed separately while the rest of the attention mechanism remains the same.

```py
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs
```

Assign the custom `SamVisionAttentionSplit` class to the original `SamVisionAttention` class in the `modeling_sam` module to replace it. All instances of `SamVisionAttention` in the model are replaced with the split attention version.

Load the model with [`~PreTrainedModel.from_pretrained`].

```py
from transformers import SamModel
from transformers.models.sam import modeling_sam

# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
```
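To confirm the replacement took effect, you can inspect the attention module on one of the vision encoder layers. This is a minimal sanity check; the `vision_encoder.layers[...].attn` attribute path reflects the current SAM implementation and may differ across Transformers versions.

```py
# sanity check: the vision encoder layers should now use the split attention class
print(type(model.vision_encoder.layers[0].attn))
# expected: SamVisionAttentionSplit
```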

## LoRA

With separate `q`, `k`, and `v` projections, apply LoRA to the `q` and `v` projections.

Create a `LoraConfig` and specify the rank `r`, `lora_alpha`, `lora_dropout`, `task_type`, and most importantly, the modules to target.

```py
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    # apply LoRA to q and v
    target_modules=["q", "v"],
    lora_dropout=0.1,
    task_type="mask-generation"
)
```

Pass the model and `LoraConfig` to `get_peft_model` to apply LoRA to the model.

```py
model = get_peft_model(model, config)
```

Call `print_trainable_parameters` to view the number of parameters you're training versus the total number of parameters.

```py
model.print_trainable_parameters()
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
```