BLIP: this is correct now (#35081)

this is correct now
This commit is contained in:
Raushan Turganbay 2024-12-05 16:30:09 +01:00 committed by GitHub
parent 50189e36a6
commit e682c17e4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 4 deletions

View File

@ -2311,7 +2311,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

View File

@ -1593,7 +1593,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

View File

@ -1628,7 +1628,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

View File

@ -441,7 +441,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)