mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Remove float64 cast for OwlVit and OwlV2 to support MPS device (#31071)
Remove float64
This commit is contained in:
parent
936ab7bae5
commit
c31473ed44
@ -1276,7 +1276,6 @@ class Owlv2ClassPredictionHead(nn.Module):
|
||||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = pred_logits.to(torch.float64)
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
|
@ -1257,7 +1257,6 @@ class OwlViTClassPredictionHead(nn.Module):
|
||||
if query_mask.ndim > 1:
|
||||
query_mask = torch.unsqueeze(query_mask, dim=-2)
|
||||
|
||||
pred_logits = pred_logits.to(torch.float64)
|
||||
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
|
||||
pred_logits = pred_logits.to(torch.float32)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user