Fix redundant code in Janus (#38826)

* minor mistake

* modify return statements
This commit is contained in:
Yaswanth Gali 2025-06-16 12:23:59 +05:30 committed by GitHub
parent d2fd3868bb
commit 925da8ac56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 17 deletions

View File

@ -1007,9 +1007,8 @@ class JanusVQVAE(JanusPreTrainedModel):
batch_size = pixel_values.shape[0]
quant, embedding_loss, indices = self.encode(pixel_values)
decoded_pixel_values = self.decode(indices.view(batch_size, -1))
output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
return output
return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
class JanusVQVAEAlignerMLP(nn.Module):
@ -1151,7 +1150,7 @@ class JanusModel(JanusPreTrainedModel):
**kwargs,
)
output = JanusBaseModelOutputWithPast(
return JanusBaseModelOutputWithPast(
last_hidden_state=lm_output.last_hidden_state,
past_key_values=lm_output.past_key_values,
hidden_states=lm_output.hidden_states,
@ -1159,8 +1158,6 @@ class JanusModel(JanusPreTrainedModel):
image_hidden_states=image_embeds if pixel_values is not None else None,
)
return output
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@ -1249,7 +1246,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
output = JanusCausalLMOutputWithPast(
return JanusCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
@ -1257,7 +1254,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
return output
def prepare_inputs_for_generation(
self,

View File

@ -49,7 +49,6 @@ from ..blip_2.modeling_blip_2 import Blip2VisionModel
from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
from ..chameleon.modeling_chameleon import (
ChameleonVQVAE,
ChameleonVQVAEEncoder,
ChameleonVQVAEEncoderAttnBlock,
ChameleonVQVAEEncoderConvDownsample,
ChameleonVQVAEEncoderResnetBlock,
@ -656,9 +655,9 @@ class JanusVQVAEMidBlock(nn.Module):
return hidden_states
class JanusVQVAEEncoder(ChameleonVQVAEEncoder, nn.Module):
class JanusVQVAEEncoder(nn.Module):
def __init__(self, config):
nn.Module.__init__()
super().__init__()
self.num_resolutions = len(config.channel_multiplier)
self.num_res_blocks = config.num_res_blocks
@ -845,9 +844,8 @@ class JanusVQVAE(ChameleonVQVAE):
batch_size = pixel_values.shape[0]
quant, embedding_loss, indices = self.encode(pixel_values)
decoded_pixel_values = self.decode(indices.view(batch_size, -1))
output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
return output
return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
class JanusVQVAEAlignerMLP(nn.Module):
@ -989,7 +987,7 @@ class JanusModel(JanusPreTrainedModel):
**kwargs,
)
output = JanusBaseModelOutputWithPast(
return JanusBaseModelOutputWithPast(
last_hidden_state=lm_output.last_hidden_state,
past_key_values=lm_output.past_key_values,
hidden_states=lm_output.hidden_states,
@ -997,8 +995,6 @@ class JanusModel(JanusPreTrainedModel):
image_hidden_states=image_embeds if pixel_values is not None else None,
)
return output
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@ -1087,7 +1083,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
output = JanusCausalLMOutputWithPast(
return JanusCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
@ -1095,7 +1091,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
return output
def prepare_inputs_for_generation(
self,