[static cache] fix device map per layer in VLMs (#38488)

return language model as decoder
Raushan Turganbay 2025-06-20 13:49:29 +02:00 committed by GitHub
parent aa42987c1e
commit ff95974bc6
21 changed files with 162 additions and 36 deletions
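
Every file in this diff applies the same two-part pattern: the inner `*Model` composite gains `set_decoder`/`get_decoder` accessors that point at its language model, and the outer `*ForConditionalGeneration` head delegates to them instead of exposing the whole vision+text composite. A minimal sketch of the before/after shape, with illustrative class and attribute names rather than code from any one file:

import torch.nn as nn


class VLMModel(nn.Module):
    """Composite backbone holding a vision tower and a language model (illustrative)."""

    def __init__(self, vision_tower: nn.Module, language_model: nn.Module):
        super().__init__()
        self.vision_tower = vision_tower
        self.language_model = language_model

    # New in this commit: the composite exposes its text decoder directly.
    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model


class VLMForConditionalGeneration(nn.Module):
    """Generation head wrapping the composite backbone (illustrative)."""

    def __init__(self, model: VLMModel, lm_head: nn.Module):
        super().__init__()
        self.model = model
        self.lm_head = lm_head

    def set_decoder(self, decoder):
        # Before this commit: self.model = decoder (replaced the whole composite).
        self.model.set_decoder(decoder)

    def get_decoder(self):
        # Before this commit: return self.model (the vision+text composite).
        return self.model.get_decoder()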

View File

@@ -1056,6 +1056,12 @@ class AriaModel(AriaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -1220,10 +1226,10 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,
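
Why this matters for the static cache: with a per-layer device map, cache allocation walks the decoder returned by `get_decoder()` to decide which device each layer's key/value buffers should live on. Returning the vision+text composite instead of the language model breaks that layer-to-device correspondence. A hedged sketch of that consumer side; the helper and the `layers` attribute are assumptions for illustration, not code from the transformers source:

def per_layer_cache_devices(model):
    # get_decoder() must resolve to the language model itself so its
    # transformer layer stack (assumed to sit at `decoder.layers`) can be
    # enumerated; each cache layer then follows its weights' device.
    decoder = model.get_decoder()
    return [next(layer.parameters()).device for layer in decoder.layers]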

View File

@@ -211,6 +211,12 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -389,10 +395,10 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -1451,6 +1451,12 @@ class Emu3Model(Emu3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.text_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.text_model = decoder
+
+    def get_decoder(self):
+        return self.text_model
+
     def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
         """
         Tokenizes images into discrete tokens with VQGAN module. Converts
@@ -1599,10 +1605,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     # Make modules available through the conditional class for BC
     @property

View File

@@ -938,6 +938,12 @@ class Emu3Model(Emu3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.text_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.text_model = decoder
+
+    def get_decoder(self):
+        return self.text_model
+
     def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
         """
         Tokenizes images into discrete tokens with VQGAN module. Converts
@@ -1086,10 +1092,10 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     # Make modules available through the conditional class for BC
     @property

View File

@@ -86,6 +86,12 @@ class FuyuModel(FuyuPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def gather_continuous_embeddings(
         self,
         word_embeddings: torch.Tensor,

View File

@@ -829,6 +829,12 @@ class Gemma3Model(Gemma3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
@@ -1014,6 +1020,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

View File

@@ -637,6 +637,12 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -757,10 +763,10 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -627,6 +627,12 @@ class InternVLModel(InternVLPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -878,10 +884,10 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -181,6 +181,12 @@ class LlavaModel(LlavaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -371,10 +377,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -294,6 +294,12 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -569,10 +575,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
         return self.model.pack_image_features(

View File

@@ -348,6 +348,12 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -701,10 +707,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
         return self.model.pack_image_features(

View File

@@ -350,6 +350,12 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -742,10 +748,10 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
         return self.model.pack_image_features(

View File

@@ -248,6 +248,12 @@ class Mistral3Model(Mistral3PreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
@@ -407,10 +413,10 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -1641,6 +1641,12 @@ class MllamaModel(MllamaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -1792,10 +1798,10 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     # Make modules available through the conditional class for BC
     @property

View File

@@ -173,6 +173,12 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def _update_causal_mask(
         self,
         attention_mask,
@@ -418,10 +424,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(self, pixel_values):
         return self.model.get_image_features(pixel_values)

View File

@@ -1847,6 +1847,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
     def set_input_embeddings(self, value):
         self.model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
     def get_video_features(
         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
     ):

View File

@@ -2269,6 +2269,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
     def set_input_embeddings(self, value):
         self.model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
     def get_video_features(
         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
     ):

View File

@@ -1067,6 +1067,12 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_rope_index(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1498,10 +1504,10 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_video_features(
         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None

View File

@@ -1033,6 +1033,12 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_rope_index(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1382,10 +1388,10 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_video_features(
         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None

View File

@@ -202,6 +202,12 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self,
         pixel_values_images: torch.FloatTensor,
@@ -444,10 +450,10 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self,

View File

@@ -182,6 +182,12 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
+    def set_decoder(self, decoder):
+        self.language_model = decoder
+
+    def get_decoder(self):
+        return self.language_model
+
     def get_image_features(
         self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None
     ):
@@ -327,10 +333,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
         self.lm_head = new_embeddings
 
     def set_decoder(self, decoder):
-        self.model = decoder
+        self.model.set_decoder(decoder)
 
     def get_decoder(self):
-        return self.model
+        return self.model.get_decoder()
 
     def get_image_features(
         self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None
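
End-to-end, the fix lets a generation call with a static cache on a sharded VLM place each layer's cache next to that layer's weights. A usage sketch, not a test from this commit; the checkpoint id is a placeholder and the call assumes a multi-device `device_map="auto"` setup:

from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "llava-hf/llava-1.5-7b-hf"  # placeholder checkpoint
model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text="Describe a sunset.", return_tensors="pt").to(model.device)
# With get_decoder() resolving to the language model, the static cache can be
# laid out per decoder layer across the mapped devices.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True)[0])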