From 925da8ac568c804de4085f31fc08762ff9519b4e Mon Sep 17 00:00:00 2001
From: Yaswanth Gali <82788246+yaswanth19@users.noreply.github.com>
Date: Mon, 16 Jun 2025 12:23:59 +0530
Subject: [PATCH] Fix redundant code in Janus (#38826)

* minor mistake

* modify return statements
---
 src/transformers/models/janus/modeling_janus.py | 10 +++-------
 src/transformers/models/janus/modular_janus.py  | 15 +++++----------
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py
index 6455c41eb12..7084382988f 100644
--- a/src/transformers/models/janus/modeling_janus.py
+++ b/src/transformers/models/janus/modeling_janus.py
@@ -1007,9 +1007,8 @@ class JanusVQVAE(JanusPreTrainedModel):
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
         decoded_pixel_values = self.decode(indices.view(batch_size, -1))
-        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
 
-        return output
+        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
 
 
 class JanusVQVAEAlignerMLP(nn.Module):
@@ -1151,7 +1150,7 @@ class JanusModel(JanusPreTrainedModel):
             **kwargs,
         )
 
-        output = JanusBaseModelOutputWithPast(
+        return JanusBaseModelOutputWithPast(
             last_hidden_state=lm_output.last_hidden_state,
             past_key_values=lm_output.past_key_values,
             hidden_states=lm_output.hidden_states,
@@ -1159,8 +1158,6 @@
             image_hidden_states=image_embeds if pixel_values is not None else None,
         )
 
-        return output
-
 
 class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@@ -1249,7 +1246,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
 
-        output = JanusCausalLMOutputWithPast(
+        return JanusCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
@@ -1257,7 +1254,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
         )
-        return output
 
     def prepare_inputs_for_generation(
         self,
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index 599348026d1..17b6565ddcb 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -49,7 +49,6 @@ from ..blip_2.modeling_blip_2 import Blip2VisionModel
 from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
 from ..chameleon.modeling_chameleon import (
     ChameleonVQVAE,
-    ChameleonVQVAEEncoder,
     ChameleonVQVAEEncoderAttnBlock,
     ChameleonVQVAEEncoderConvDownsample,
     ChameleonVQVAEEncoderResnetBlock,
@@ -656,9 +655,9 @@ class JanusVQVAEMidBlock(nn.Module):
         return hidden_states
 
 
-class JanusVQVAEEncoder(ChameleonVQVAEEncoder, nn.Module):
+class JanusVQVAEEncoder(nn.Module):
     def __init__(self, config):
-        nn.Module.__init__()
+        super().__init__()
 
         self.num_resolutions = len(config.channel_multiplier)
         self.num_res_blocks = config.num_res_blocks
@@ -845,9 +844,8 @@ class JanusVQVAE(ChameleonVQVAE):
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
         decoded_pixel_values = self.decode(indices.view(batch_size, -1))
-        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
 
-        return output
+        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
 
 
 class JanusVQVAEAlignerMLP(nn.Module):
@@ -989,7 +987,7 @@ class JanusModel(JanusPreTrainedModel):
             **kwargs,
         )
 
-        output = JanusBaseModelOutputWithPast(
+        return JanusBaseModelOutputWithPast(
             last_hidden_state=lm_output.last_hidden_state,
             past_key_values=lm_output.past_key_values,
             hidden_states=lm_output.hidden_states,
@@ -997,8 +995,6 @@
             image_hidden_states=image_embeds if pixel_values is not None else None,
         )
 
-        return output
-
 
 class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@@ -1087,7 +1083,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
 
-        output = JanusCausalLMOutputWithPast(
+        return JanusCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
@@ -1095,7 +1091,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
         )
-        return output
 
     def prepare_inputs_for_generation(
         self,
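
Not part of the patch: a minimal, hypothetical sketch (a toy ToyEncoder module, not the real Janus classes) of the two patterns the diff applies, replacing the broken nn.Module.__init__() call with super().__init__(), and returning the constructed output directly instead of binding it to a temporary output variable first.

import torch
from torch import nn


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size: int = 8):
        # Calling `nn.Module.__init__()` with no arguments (the bug the patch
        # fixes) raises a TypeError because `self` is missing; the idiomatic
        # call is `super().__init__()`, which the patch switches to.
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Return the result directly rather than assigning it to a temporary
        # `output` variable first, mirroring the `return Janus...Output(...)`
        # changes in the diff.
        return self.proj(hidden_states)


encoder = ToyEncoder()
print(encoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])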