Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Siglip: add _no_split_module (#31566)

* device-map siglip
* move split modules to PretrainedSigLip

This commit is contained in:
parent 74b92c6256
commit 7e86cb6c6f
@@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel):
     config_class = SiglipConfig
     base_model_prefix = "siglip"
     supports_gradient_checkpointing = True
+    _no_split_modules = [
+        "SiglipTextEmbeddings",
+        "SiglipEncoderLayer",
+        "SiglipVisionEmbeddings",
+        "SiglipEncoderLayer",
+        "SiglipMultiheadAttentionPoolingHead",
+    ]
 
     def _init_weights(self, module):
         """Initialize the weights"""
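With the no-split list now declared once on SiglipPreTrainedModel, Accelerate's big-model inference can shard SiglipModel across devices while keeping each listed block on a single device. A minimal usage sketch, not part of this commit; the checkpoint name and dtype are illustrative assumptions:

import torch
from transformers import SiglipModel

# device_map="auto" lets Accelerate place the model across available devices,
# never splitting any module class listed in _no_split_modules
model = SiglipModel.from_pretrained(
    "google/siglip-base-patch16-224",  # assumed checkpoint, swap in your own
    device_map="auto",
    torch_dtype=torch.float16,
)
print(model.hf_device_map)  # shows which device each block was assigned to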
@@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module):
 class SiglipTextModel(SiglipPreTrainedModel):
     config_class = SiglipTextConfig
 
-    _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
-
     def __init__(self, config: SiglipTextConfig):
         super().__init__(config)
         self.text_model = SiglipTextTransformer(config)
@@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
 class SiglipVisionModel(SiglipPreTrainedModel):
     config_class = SiglipVisionConfig
     main_input_name = "pixel_values"
-    _no_split_modules = ["SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"]
 
     def __init__(self, config: SiglipVisionConfig):
         super().__init__(config)
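With the per-model lists removed, SiglipTextModel and SiglipVisionModel pick up _no_split_modules through ordinary class-attribute inheritance from SiglipPreTrainedModel. A quick illustrative check, assuming a transformers version that includes this change:

from transformers import SiglipTextModel, SiglipVisionModel

# Both submodels resolve the attribute on the shared base class
print(SiglipTextModel._no_split_modules)
print(SiglipVisionModel._no_split_modules)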
@@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel):
         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
 
         # cosine similarity as logits
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
+            + self.logit_bias
+        )
         logits_per_image = logits_per_text.t()
 
         loss = None
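The reworked logit computation matters under a device map: when the text and vision towers land on different devices, image_embeds must be moved to text_embeds' device before the matmul. A standalone sketch of the same math, with made-up shapes and stand-in scale/bias values:

import torch

text_embeds = torch.nn.functional.normalize(torch.randn(4, 768), p=2, dim=-1)
image_embeds = torch.nn.functional.normalize(torch.randn(2, 768), p=2, dim=-1)
logit_scale = torch.tensor(1.0)  # stands in for self.logit_scale
logit_bias = torch.tensor(0.0)   # stands in for self.logit_bias

# .to(text_embeds.device) is a no-op here, but mirrors the cross-device case in the patch
logits_per_text = (
    torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.exp()
    + logit_bias
)
logits_per_image = logits_per_text.t()
print(logits_per_text.shape, logits_per_image.shape)  # torch.Size([4, 2]) torch.Size([2, 4])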
@@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_pruning = False
     test_resize_embeddings = False
     test_attention_outputs = False
+    # MP works but offload doesn't work when the MultiheadAttention is offloaded
+    # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+    # in the dispatch_model function
+    test_cpu_offload = False
+    test_disk_offload_safetensors = False
+    test_disk_offload_bin = False
 
     # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
     def setUp(self):
@@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_pruning = False
     test_resize_embeddings = False
     test_attention_outputs = False
+    # MP works but offload doesn't work when the MultiheadAttention is offloaded
+    # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+    # in the dispatch_model function
+    test_cpu_offload = False
+    test_disk_offload_safetensors = False
+    test_disk_offload_bin = False
 
     def setUp(self):
         self.model_tester = SiglipForImageClassificationModelTester(self)
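The TODO in both test classes points at Accelerate's preload_module_classes hook: if the pooling head is preloaded as a unit, offloading it should no longer fail. A hedged sketch of what that could look like, not what the tests do today; the checkpoint is an illustrative assumption:

from accelerate import dispatch_model, infer_auto_device_map
from transformers import SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")  # assumed checkpoint
device_map = infer_auto_device_map(model, no_split_module_classes=model._no_split_modules)

# preload_module_classes tells dispatch_model to load all weights of the pooling head
# at the start of its forward pass, instead of streaming its offloaded submodules one by one
model = dispatch_model(
    model,
    device_map=device_map,
    preload_module_classes=["SiglipMultiheadAttentionPoolingHead"],
)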