Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

commit 6be15d1230
parent 1d56648f3d

    Fixed some rebase stuff
@@ -2091,9 +2091,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
             else:
                 special_image_mask = input_ids == self.config.image_token_id

-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
             language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(special_image_mask, language_model_inputs)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
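The pattern both hunks converge on is masked_scatter: build a boolean mask over the token sequence where the image placeholders sit, expand it to the hidden dimension, then scatter the Q-Former outputs (language_model_inputs) into exactly those slots. Below is a minimal, self-contained sketch of that pattern; the id 99 and all shapes are toy stand-ins invented for illustration, not the real BLIP-2 configuration.

import torch

# Toy stand-ins for the real shapes: batch 1, sequence of 6 tokens,
# hidden size 4, two image placeholder positions.
image_token_id = 99
input_ids = torch.tensor([[99, 99, 5, 6, 7, 8]])
inputs_embeds = torch.zeros(1, 6, 4)         # text embeddings, zeros for visibility
language_model_inputs = torch.ones(1, 2, 4)  # stand-in for Q-Former outputs

# Same three steps as the patched forward():
# 1. boolean [batch, seq] mask of image positions,
# 2. expand it to [batch, seq, hidden],
# 3. scatter the image features into the masked slots, in order.
special_image_mask = input_ids == image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

print(inputs_embeds[0, :, 0])  # tensor([1., 1., 0., 0., 0., 0.])

Compared with the boolean-index assignment removed in the second hunk (inputs_embeds[special_image_mask] = language_model_inputs.flatten()), the out-of-place masked_scatter avoids in-place mutation, which is presumably friendlier to tracing and torch.compile; the explicit dtype cast guards against the Q-Former outputs arriving in a different precision than the text embeddings.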
@@ -2231,11 +2235,21 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         # if the model already has "image_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
-            inputs_embeds = inputs_embeds.to(language_model_inputs.device)
-            special_image_mask = special_image_mask.to(language_model_inputs.device)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
+            language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )

             attention_mask = attention_mask.to(language_attention_mask.device)
         else:
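This second hunk also adds a fallback for the case where generate() receives inputs_embeds but no input_ids, so there are no token ids to compare against the placeholder: the image positions are recovered by embedding the placeholder id and testing which embedding rows match it exactly. A toy sketch of that recovery step follows; the vocabulary size, hidden size, and the id 7 are invented stand-ins, not the real BLIP-2 wiring.

import torch
import torch.nn as nn

# Tiny embedding table standing in for the language model's input embeddings.
vocab_size, hidden_size, image_token_id = 10, 4, 7
embed = nn.Embedding(vocab_size, hidden_size)

# What a caller might pass instead of ids: two placeholder rows, two text rows.
inputs_embeds = embed(torch.tensor([[7, 7, 1, 2]]))  # [1, 4, hidden]

# Recover the mask by comparing every row against a fresh lookup of the
# placeholder id, then requiring all hidden channels to match.
image_row = embed(torch.tensor(image_token_id, dtype=torch.long))
special_image_mask = (inputs_embeds == image_row).all(-1)

print(special_image_mask)  # tensor([[ True,  True, False, False]])

The exact comparison is reliable here because the placeholder rows in inputs_embeds came out of the same embedding table, so a fresh lookup of image_token_id reproduces them bit for bit.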