mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
50189e36a6
commit
e682c17e4a
@ -2311,7 +2311,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
|
||||
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
|
@ -1593,7 +1593,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
|
||||
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
|
@ -1628,7 +1628,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
|
@ -441,7 +441,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
|
||||
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user