Bug Fix for issue #34294 (#34295)

Update SiglipVisionEmbeddings.forward to cast the input to the correct dtype before embedding it.
fpgaminer 2024-10-31 10:51:15 -07:00 committed by GitHub
parent 114dd812dd
commit c443d8d536

@@ -308,7 +308,8 @@ class SiglipVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
         _, _, height, width = pixel_values.shape
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
         if interpolate_pos_encoding:
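
For context, below is a minimal, self-contained sketch of the failure mode this change guards against. The ToyPatchEmbeddings module is hypothetical (a stand-in for SiglipVisionEmbeddings, not the real class): when the patch embedding weights are held in reduced precision but the incoming pixel_values are fp32, the convolution hits a dtype mismatch, and casting the input to the weight dtype, as the diff does, resolves it.

```python
# Hypothetical stand-in for SiglipVisionEmbeddings; illustrates only the dtype cast.
import torch
import torch.nn as nn


class ToyPatchEmbeddings(nn.Module):
    def __init__(self, embed_dim: int = 8, patch_size: int = 4):
        super().__init__()
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # Without this cast, fp32 pixel_values fed to reduced-precision
        # weights raise a dtype mismatch inside the convolution.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # [batch, embed_dim, grid, grid] -> [batch, num_patches, embed_dim]
        return patch_embeds.flatten(2).transpose(1, 2)


module = ToyPatchEmbeddings().to(torch.bfloat16)  # reduced-precision weights
pixels = torch.randn(1, 3, 16, 16)                # fp32, as an image processor typically returns
print(module(pixels).dtype)                       # torch.bfloat16
```

Before this fix, the same mismatch could be triggered in practice by loading the vision model with a reduced-precision torch_dtype while passing the processor's default fp32 pixel_values (an assumed scenario; see issue #34294 for the original report).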