Fix rebase leftovers in Blip2ForConditionalGeneration: restore the masked_scatter-based image-token embedding merge (with input_ids == None handling) in forward and generate

This commit is contained in:
remi-or 2025-07-02 04:59:02 -05:00
parent 1d56648f3d
commit 6be15d1230

View File

@ -2091,9 +2091,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
special_image_mask = (
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
)
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(special_image_mask, language_model_inputs)
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
special_image_mask, language_model_inputs
)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
@ -2231,11 +2235,21 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
# if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
inputs_embeds = inputs_embeds.to(language_model_inputs.device)
special_image_mask = special_image_mask.to(language_model_inputs.device)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
special_image_mask = (
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
)
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
special_image_mask, language_model_inputs
)
attention_mask = attention_mask.to(language_attention_mask.device)
else: