From 481b953170ca66f7999b1d448cfd2a814dd28c35 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Thu, 5 Jun 2025 21:19:07 +0200
Subject: [PATCH] Fix `return_dict=False` giving errors in a few VLM models
 (#38519)

update

Co-authored-by: ydshieh
---
 src/transformers/models/chameleon/modeling_chameleon.py   | 4 +---
 src/transformers/models/kosmos2/modeling_kosmos2.py       | 3 ---
 src/transformers/models/llava_next/modeling_llava_next.py | 4 +---
 src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 8 +-------
 src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py  | 8 +-------
 src/transformers/models/qwen2_vl/modeling_qwen2_vl.py     | 4 +---
 6 files changed, 5 insertions(+), 26 deletions(-)

diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 5e39d2515e0..5644f68baa9 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1243,7 +1243,6 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -1277,7 +1276,6 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
@@ -1290,7 +1288,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py
index d25f100a234..ebf578fedfc 100644
--- a/src/transformers/models/kosmos2/modeling_kosmos2.py
+++ b/src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -1809,7 +1809,6 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
         r"""
@@ -1868,7 +1867,6 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         vision_model_output = None
         projection_attentions = None
@@ -1880,7 +1878,6 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
                 pixel_values=pixel_values,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
             )
             # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
             image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index d0ea1bb233c..e5b597e819e 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -604,7 +604,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[KwargsForCausalLM],
@@ -645,7 +644,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
@@ -668,7 +666,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index b9d8985ebbc..1b4cfea5b66 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1525,7 +1525,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
         pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
@@ -1588,7 +1587,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.model(
             input_ids=input_ids,
@@ -1604,7 +1602,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1616,10 +1614,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Qwen2_5_VLCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index e38f2952144..f293f5c769c 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -770,7 +770,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
         pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
@@ -833,7 +832,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.model(
             input_ids=input_ids,
@@ -849,7 +847,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
@@ -861,10 +859,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
         return Qwen2_5_VLCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 4581ed08279..a995d064b52 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1409,7 +1409,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
         pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
@@ -1469,7 +1468,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.model(
             input_ids=input_ids,
@@ -1484,7 +1482,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
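
A minimal sketch of the calling convention these heads follow after this change: `return_dict` is gone from their forward() signatures and they always return a typed ModelOutput, while the hardcoded `return_dict=True` on the inner `self.model(...)` call keeps the heads' own attribute access (`outputs.logits`) valid regardless of config defaults. Callers that relied on `return_dict=False` tuples can convert explicitly with `ModelOutput.to_tuple()`. The checkpoint id and dummy input below are illustrative assumptions, not taken from this patch.

    # Sketch only; assumes a Qwen2-VL checkpoint is available.
    import torch
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    model_id = "Qwen/Qwen2-VL-2B-Instruct"  # assumed checkpoint, for illustration
    processor = AutoProcessor.from_pretrained(model_id)
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)

    # Text-only dummy input; pixel_values stay optional on this head.
    inputs = processor(text=["Describe the sky."], return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)  # no `return_dict` argument on this head anymore

    print(type(outputs).__name__)  # Qwen2VLCausalLMOutputWithPast
    print(outputs.logits.shape)    # (batch, seq_len, vocab_size)

    # Callers that previously passed `return_dict=False` for tuple outputs
    # can convert explicitly instead:
    legacy = outputs.to_tuple()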