mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Fix redundant code in Janus (#38826)
* minor mistake
* modify return statements
This commit is contained in:
parent
d2fd3868bb
commit
925da8ac56
@ -1007,9 +1007,8 @@ class JanusVQVAE(JanusPreTrainedModel):
|
||||
batch_size = pixel_values.shape[0]
|
||||
quant, embedding_loss, indices = self.encode(pixel_values)
|
||||
decoded_pixel_values = self.decode(indices.view(batch_size, -1))
|
||||
output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
|
||||
|
||||
return output
|
||||
return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
|
||||
|
||||
|
||||
class JanusVQVAEAlignerMLP(nn.Module):
|
||||
@ -1151,7 +1150,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = JanusBaseModelOutputWithPast(
|
||||
return JanusBaseModelOutputWithPast(
|
||||
last_hidden_state=lm_output.last_hidden_state,
|
||||
past_key_values=lm_output.past_key_values,
|
||||
hidden_states=lm_output.hidden_states,
|
||||
@ -1159,8 +1158,6 @@ class JanusModel(JanusPreTrainedModel):
|
||||
image_hidden_states=image_embeds if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
|
||||
@ -1249,7 +1246,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
|
||||
output = JanusCausalLMOutputWithPast(
|
||||
return JanusCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -1257,7 +1254,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
return output
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
|
@ -49,7 +49,6 @@ from ..blip_2.modeling_blip_2 import Blip2VisionModel
|
||||
from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
|
||||
from ..chameleon.modeling_chameleon import (
|
||||
ChameleonVQVAE,
|
||||
ChameleonVQVAEEncoder,
|
||||
ChameleonVQVAEEncoderAttnBlock,
|
||||
ChameleonVQVAEEncoderConvDownsample,
|
||||
ChameleonVQVAEEncoderResnetBlock,
|
||||
@ -656,9 +655,9 @@ class JanusVQVAEMidBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JanusVQVAEEncoder(ChameleonVQVAEEncoder, nn.Module):
|
||||
class JanusVQVAEEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
nn.Module.__init__()
|
||||
super().__init__()
|
||||
|
||||
self.num_resolutions = len(config.channel_multiplier)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
@ -845,9 +844,8 @@ class JanusVQVAE(ChameleonVQVAE):
|
||||
batch_size = pixel_values.shape[0]
|
||||
quant, embedding_loss, indices = self.encode(pixel_values)
|
||||
decoded_pixel_values = self.decode(indices.view(batch_size, -1))
|
||||
output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
|
||||
|
||||
return output
|
||||
return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
|
||||
|
||||
|
||||
class JanusVQVAEAlignerMLP(nn.Module):
|
||||
@ -989,7 +987,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = JanusBaseModelOutputWithPast(
|
||||
return JanusBaseModelOutputWithPast(
|
||||
last_hidden_state=lm_output.last_hidden_state,
|
||||
past_key_values=lm_output.past_key_values,
|
||||
hidden_states=lm_output.hidden_states,
|
||||
@ -997,8 +995,6 @@ class JanusModel(JanusPreTrainedModel):
|
||||
image_hidden_states=image_embeds if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
|
||||
@ -1087,7 +1083,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
|
||||
output = JanusCausalLMOutputWithPast(
|
||||
return JanusCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -1095,7 +1091,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
return output
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user