Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Fix redundant code in Janus (#38826)

* minor mistake
* modify return statements

This commit is contained in:
parent d2fd3868bb
commit 925da8ac56
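The diff below repeats the same cleanup across several forward methods: previously the output dataclass was bound to a temporary `output` variable and returned a couple of lines later; the commit now constructs and returns it in one statement. A minimal sketch of the before/after pattern, using a stand-in dataclass rather than the actual Janus classes:

from dataclasses import dataclass

import torch


@dataclass
class StubVQVAEOutput:
    # Stand-in for the library's ModelOutput-style classes (JanusVQVAEOutput etc.).
    decoded_pixel_values: torch.Tensor
    embedding_loss: torch.Tensor


def forward_before(decoded: torch.Tensor, loss: torch.Tensor) -> StubVQVAEOutput:
    # Old pattern: build the output, bind it to a temporary, return the temporary.
    output = StubVQVAEOutput(decoded, loss)
    return output


def forward_after(decoded: torch.Tensor, loss: torch.Tensor) -> StubVQVAEOutput:
    # New pattern: construct and return in one step; callers receive the same object.
    return StubVQVAEOutput(decoded, loss)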
@@ -1007,9 +1007,8 @@ class JanusVQVAE(JanusPreTrainedModel):
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
         decoded_pixel_values = self.decode(indices.view(batch_size, -1))
-        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)

-        return output
+        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)


 class JanusVQVAEAlignerMLP(nn.Module):
@@ -1151,7 +1150,7 @@ class JanusModel(JanusPreTrainedModel):
             **kwargs,
         )

-        output = JanusBaseModelOutputWithPast(
+        return JanusBaseModelOutputWithPast(
             last_hidden_state=lm_output.last_hidden_state,
             past_key_values=lm_output.past_key_values,
             hidden_states=lm_output.hidden_states,
@@ -1159,8 +1158,6 @@ class JanusModel(JanusPreTrainedModel):
             image_hidden_states=image_embeds if pixel_values is not None else None,
         )

-        return output
-

 class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@@ -1249,7 +1246,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

-        output = JanusCausalLMOutputWithPast(
+        return JanusCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
@@ -1257,7 +1254,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
         )
-        return output

     def prepare_inputs_for_generation(
         self,
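The hunks above are in the generated modeling code; the hunks below, whose numbering restarts at line 49 and which import from `..chameleon.modeling_chameleon`, apply the same return-statement cleanup in the second changed file and additionally fix the `JanusVQVAEEncoder` base class and constructor. Continuing the stand-in sketch after the commit message, a quick check that dropping the temporary binding is behaviour-preserving (illustrative names only, not from the diff):

import torch

# Reuses forward_before / forward_after from the sketch above.
decoded = torch.zeros(2, 3, 8, 8)
loss = torch.tensor(0.0)

old_style = forward_before(decoded, loss)
new_style = forward_after(decoded, loss)

# Same type, same field values: callers of forward() are unaffected.
assert type(old_style) is type(new_style)
assert torch.equal(old_style.decoded_pixel_values, new_style.decoded_pixel_values)
assert torch.equal(old_style.embedding_loss, new_style.embedding_loss)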
@@ -49,7 +49,6 @@ from ..blip_2.modeling_blip_2 import Blip2VisionModel
 from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
 from ..chameleon.modeling_chameleon import (
     ChameleonVQVAE,
-    ChameleonVQVAEEncoder,
     ChameleonVQVAEEncoderAttnBlock,
     ChameleonVQVAEEncoderConvDownsample,
     ChameleonVQVAEEncoderResnetBlock,
@@ -656,9 +655,9 @@ class JanusVQVAEMidBlock(nn.Module):
         return hidden_states


-class JanusVQVAEEncoder(ChameleonVQVAEEncoder, nn.Module):
+class JanusVQVAEEncoder(nn.Module):
     def __init__(self, config):
-        nn.Module.__init__()
+        super().__init__()

         self.num_resolutions = len(config.channel_multiplier)
         self.num_res_blocks = config.num_res_blocks
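The "minor mistake" from the commit message is the constructor call above: `nn.Module.__init__()` is invoked on the class with no arguments, so `self` is never passed (plain Python raises a TypeError for that call) and the cooperative `super()` chain is bypassed. With `ChameleonVQVAEEncoder` removed from the base classes, `super().__init__()` is the idiomatic fix: it initialises the `nn.Module` bookkeeping (parameter, buffer and submodule registries) before any attributes are assigned. A minimal sketch with a stand-in encoder, not the real Janus one:

import torch.nn as nn


class StubEncoder(nn.Module):
    # Stand-in for JanusVQVAEEncoder; the real class builds conv/resnet blocks from a config.
    def __init__(self, hidden_size: int = 4):
        # nn.Module.__init__()      # old form: no `self`, raises TypeError if executed
        super().__init__()          # fixed form: runs nn.Module.__init__ via the MRO
        self.proj = nn.Conv2d(3, hidden_size, kernel_size=3, padding=1)


encoder = StubEncoder()
# Submodules and parameters register correctly only because nn.Module state was initialised.
print(sum(p.numel() for p in encoder.parameters()))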
@@ -845,9 +844,8 @@ class JanusVQVAE(ChameleonVQVAE):
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
         decoded_pixel_values = self.decode(indices.view(batch_size, -1))
-        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)

-        return output
+        return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)


 class JanusVQVAEAlignerMLP(nn.Module):
@@ -989,7 +987,7 @@ class JanusModel(JanusPreTrainedModel):
             **kwargs,
         )

-        output = JanusBaseModelOutputWithPast(
+        return JanusBaseModelOutputWithPast(
             last_hidden_state=lm_output.last_hidden_state,
             past_key_values=lm_output.past_key_values,
             hidden_states=lm_output.hidden_states,
@@ -997,8 +995,6 @@ class JanusModel(JanusPreTrainedModel):
             image_hidden_states=image_embeds if pixel_values is not None else None,
         )

-        return output
-

 class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
@@ -1087,7 +1083,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

-        output = JanusCausalLMOutputWithPast(
+        return JanusCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
@@ -1095,7 +1091,6 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
         )
-        return output

     def prepare_inputs_for_generation(
         self,