Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
Fixed some rebase stuff
commit 6be15d1230
parent 1d56648f3d
@@ -2091,9 +2091,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
             else:
                 special_image_mask = input_ids == self.config.image_token_id
 
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
             language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(special_image_mask, language_model_inputs)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
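This first hunk is cosmetic: both sides compute the same thing, and the rebase only re-wraps the two long lines. The pattern it preserves is worth spelling out: `masked_scatter` copies the projected vision embeddings into the sequence positions that hold the image placeholder token. Below is a minimal runnable sketch of that pattern with toy tensors; the `IMAGE_TOKEN_ID` value and all shapes are illustrative, not taken from the BLIP-2 config.

import torch

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id, not the real BLIP-2 value

batch, seq_len, hidden = 1, 6, 4
num_image_tokens = 2

# toy text embeddings and input ids with two image placeholder positions
input_ids = torch.tensor([[IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 11, 12, 13, 14]])
inputs_embeds = torch.zeros(batch, seq_len, hidden)
language_model_inputs = torch.ones(batch, num_image_tokens, hidden)  # stand-in for the Q-Former output

# boolean mask over token positions, broadcast to the hidden dimension
special_image_mask = input_ids == IMAGE_TOKEN_ID
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter fills the masked positions, in order, from the source tensor
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

assert inputs_embeds[0, :2].eq(1).all() and inputs_embeds[0, 2:].eq(0).all()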
@@ -2231,11 +2235,21 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         # if the model already has "image_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
 
-            inputs_embeds = inputs_embeds.to(language_model_inputs.device)
-            special_image_mask = special_image_mask.to(language_model_inputs.device)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
+            language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )
 
             attention_mask = attention_mask.to(language_attention_mask.device)
         else:
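The second hunk is the substantive part of the conflict resolution: `generate` gains an `input_ids is None` fallback that recovers the placeholder positions by comparing each embedding row against the embedding of the image token, and the in-place boolean-index assignment (`inputs_embeds[special_image_mask] = ...`) is replaced by the same `masked_scatter` pattern used in the `forward` path above, so the two code paths stay identical. A minimal sketch of the fallback with a toy embedding table; `IMAGE_TOKEN_ID` and the table size are illustrative.

import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 5  # hypothetical placeholder id for this toy table
embed = nn.Embedding(10, 4)

input_ids = torch.tensor([[IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 1, 2]])
inputs_embeds = embed(input_ids)  # pretend only the embeddings were passed to generate()

# recover placeholder positions by comparing each row to the image token's embedding;
# a position is a placeholder only if every hidden dimension matches, hence .all(-1)
image_token_embed = embed(torch.tensor(IMAGE_TOKEN_ID, dtype=torch.long))
special_image_mask = (inputs_embeds == image_token_embed).all(-1)

assert special_image_mask.tolist() == [[True, True, False, False]]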