Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Update SiglipVisionEmbeddings.forward to cast the input to the correct dtype before embedding it.
commit c443d8d536
parent 114dd812dd
@@ -308,7 +308,8 @@ class SiglipVisionEmbeddings(nn.Module):
     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
         _, _, height, width = pixel_values.shape
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
         if interpolate_pos_encoding:
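The sketch below (not part of the commit) illustrates why the cast matters, using a standalone Conv2d as a stand-in for the SigLIP patch embedding: when the module's weights are in half precision but the image processor still returns float32 pixel values, the convolution rejects the input unless it is first cast to the weight dtype, which is exactly what the added lines do. The layer sizes here are illustrative, not taken from the commit.

# Minimal sketch, assuming a SigLIP-like patch embedding held in reduced precision.
import torch
import torch.nn as nn

# fp16 on GPU, bf16 on CPU, so the example runs on either backend
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

patch_embedding = nn.Conv2d(3, 768, kernel_size=16, stride=16).to(device=device, dtype=dtype)
pixel_values = torch.randn(1, 3, 224, 224, device=device)  # float32, as an image processor typically produces

# Without the cast, the conv raises a dtype-mismatch error; casting to the
# weight dtype (the fix applied in this commit) makes the call succeed.
target_dtype = patch_embedding.weight.dtype
patch_embeds = patch_embedding(pixel_values.to(dtype=target_dtype))
print(patch_embeds.dtype)  # torch.float16 on GPU, torch.bfloat16 on CPU

Reading the dtype from the weight (rather than hard-coding one) keeps the forward pass correct regardless of whether the model was loaded in fp32, fp16, or bf16.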