[doc] fix bugs in how_to_hack_models.md (#38198)

fix several bugs
This commit is contained in:
Fanli Lin 2025-05-20 01:37:54 +08:00 committed by GitHub
parent f524439cc5
commit 9ecee14378
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
```py
from transformers import SamModel
from transformers.models.sam import modeling_sam
# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# replace the attention class in the vision_encoder module
for layer in model.vision_encoder.layers:
if hasattr(layer, "attn"):
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
```
## LoRA
@ -138,7 +134,7 @@ config = LoraConfig(
# apply LoRA to q and v
target_modules=["q", "v"],
lora_dropout=0.1,
task_type="mask-generation"
task_type="FEATURE_EXTRACTION"
)
```
@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
```py
model.print_trainable_parameters()
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
```