mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update clip loss calculation (#13217)
* Update clip loss calculation Hello, I'm the author of the blog you took the snippet from. I think this way of calculating is possibly slightly more accurate for calculation. * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent
0a22335e66
commit
872e6be03d
@ -61,14 +61,13 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
# contrastive loss function, adapted from
|
||||
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
||||
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
|
||||
return -neg_ce.mean()
|
||||
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
||||
|
||||
|
||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||
caption_loss = contrastive_loss(similarity, dim=0)
|
||||
image_loss = contrastive_loss(similarity, dim=1)
|
||||
caption_loss = contrastive_loss(similarity)
|
||||
image_loss = contrastive_loss(similarity.T)
|
||||
return (caption_loss + image_loss) / 2.0
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user