Fix position embeddings singular/plural (#33678)

* fix position embeddings

* [run-slow] blip, blip_2, instructblip, instructblipvideo

* fix init

* [run-slow] blip, blip_2, instructblip, instructblipvideo

* fix copies

* [run-slow] blip, blip_2, instructblip, instructblipvideo

* [run-slow] blip, blip_2, instructblip, instructblipvideo

* handle exception where list + tensors are cat'd

* [run-slow] blip, blip_2, instructblip, instructblipvideo

* add missing default

* [run-slow] blip, blip_2, instructblip, instructblipvideo
Author:  Pablo Montalvo (committed by GitHub)
Date:    2024-09-26 19:07:00 +02:00
Parent:  77b47e6645
Commit:  9f97c39384
5 changed files with 21 additions and 25 deletions
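Background: all four vision embedding classes register the parameter as self.position_embedding (singular) in __init__, while the interpolate_pos_encoding method copied from ViT still referenced self.position_embeddings (plural), so interpolating at a non-default resolution raised an AttributeError. A minimal repro sketch, assuming BlipVisionModel.forward exposes an interpolate_pos_encoding flag as ViT does (model choice and shapes are illustrative, not taken from the diff):

import torch
from transformers import BlipVisionConfig, BlipVisionModel

# randomly initialized vision tower; the default BlipVisionConfig image_size is 384
model = BlipVisionModel(BlipVisionConfig()).eval()

# a resolution larger than the config forces interpolate_pos_encoding to actually interpolate
pixel_values = torch.randn(1, 3, 512, 512)

with torch.no_grad():
    # before this fix: AttributeError ('BlipVisionEmbeddings' object has no attribute 'position_embeddings')
    # after this fix: the position grid is resized to match the 512x512 patch grid
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)

print(outputs.last_hidden_state.shape)  # (1, 1 + (512 // patch_size) ** 2, hidden_size)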

src/transformers/models/blip/modeling_blip.py

@@ -233,7 +233,6 @@ class BlipVisionEmbeddings(nn.Module):
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -245,14 +244,14 @@ class BlipVisionEmbeddings(nn.Module):
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]
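The same renaming is applied to the BLIP-2, InstructBLIP, and InstructBlipVideo copies below. For reference, a standalone sketch of the interpolation the patched method performs, written as a free function with illustrative shapes (not copied from the repo):

import math

import torch
import torch.nn as nn

def interpolate_pos_encoding(position_embedding: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor:
    # position_embedding: (1, 1 + num_positions, dim); slot 0 holds the [CLS] embedding
    class_pos_embed = position_embedding[:, :1]
    patch_pos_embed = position_embedding[:, 1:]
    dim = position_embedding.shape[-1]

    new_height, new_width = height // patch_size, width // patch_size
    old_size = int(math.sqrt(patch_pos_embed.shape[1]))

    # reshape the flat patch embeddings into a 2D grid, resize it bicubically, flatten it back
    patch_pos_embed = patch_pos_embed.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

# e.g. a 24x24 grid (384px / 16px patches) resized for 512px inputs -> 32x32 grid
pos_embed = torch.randn(1, 1 + 24 * 24, 768)
print(interpolate_pos_encoding(pos_embed, 512, 512, 16).shape)  # torch.Size([1, 1025, 768])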

src/transformers/models/blip_2/modeling_blip_2.py

@@ -200,7 +200,6 @@ class Blip2VisionEmbeddings(nn.Module):
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -212,14 +211,14 @@ class Blip2VisionEmbeddings(nn.Module):
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]

src/transformers/models/instructblip/modeling_instructblip.py

@@ -104,7 +104,6 @@ class InstructBlipVisionEmbeddings(nn.Module):
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -116,14 +115,14 @@ class InstructBlipVisionEmbeddings(nn.Module):
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]

src/transformers/models/instructblip/processing_instructblip.py

@@ -122,8 +122,10 @@ class InstructBlipProcessor(ProcessorMixin):
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-        _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        # we have to concatenate lists - so we keep track of return_tensors here
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+        output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
 
         # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
         # because BLIP expects image tokens to be at the beginning even before BOS token
         if self.num_query_tokens is not None and images is not None:
@@ -145,9 +147,7 @@ class InstructBlipProcessor(ProcessorMixin):
             )
 
         # cast to desired return tensors type after concatenating
-        text_encoding = BatchEncoding(
-            text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")
-        )
+        text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
         encoding.update(text_encoding)
 
         qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
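The processor change exists because the per-sample image-token ids are built as plain Python lists: if the tokenizer had already returned torch tensors, prepending them by list concatenation would fail, so tokenization now happens with return_tensors=None and the batch is cast only after concatenation. A rough sketch of that pattern (the tokenizer checkpoint and placeholder ids are illustrative, not what the processor uses):

from transformers import AutoTokenizer, BatchEncoding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer works for this sketch
image_token_ids = [0] * 4  # hypothetical placeholder ids standing in for the query tokens

# tokenize without return_tensors so input_ids stay Python lists that can be concatenated ...
encoded = tokenizer(["a photo of a cat", "a photo of a dog"], return_tensors=None)
input_ids = [image_token_ids + ids for ids in encoded["input_ids"]]
attention_mask = [[1] * len(image_token_ids) + mask for mask in encoded["attention_mask"]]

# ... and cast the concatenated batch to the requested tensor type in one place afterwards
batch = BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}, tensor_type="pt")
print(batch["input_ids"].shape)  # torch.Size([2, 4 + sequence_length])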

src/transformers/models/instructblipvideo/modeling_instructblipvideo.py

@@ -111,7 +111,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
         self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -123,14 +122,14 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
         """
 
         num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1
 
         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding
 
-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]
 
         dim = embeddings.shape[-1]