Update clip loss calculation (#13217)

* Update clip loss calculation

Hello, I'm the author of the blog you took the snippet from. I think this way of computing the loss is possibly slightly more accurate.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sachin Abeywardana 2021-09-02 16:45:56 +10:00 committed by GitHub
parent 0a22335e66
commit 872e6be03d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -61,14 +61,13 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the negated mean of the diagonal of ``log_softmax(logits, dim)``.

    Args:
        logits: Score matrix. NOTE(review): presumably a square 2-D
            similarity matrix whose (i, i) entries are the matching
            pairs -- confirm against the caller.
        dim: Axis over which the log-softmax is normalised.

    Returns:
        A scalar tensor: minus the average diagonal log-probability.
    """
    log_probs = nn.functional.log_softmax(logits, dim=dim)
    # Only the diagonal entries contribute to the loss.
    diagonal_log_probs = torch.diag(log_probs)
    return -diagonal_log_probs.mean()
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy of ``logits`` where row ``i`` has target class ``i``.

    Args:
        logits: Score matrix. NOTE(review): presumably square, so that
            every row index is also a valid column index -- confirm
            against the caller.

    Returns:
        The result of ``cross_entropy`` with its default reduction.
    """
    # Targets are 0..len(logits)-1, created on the same device as the logits.
    targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, targets)
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Average the contrastive loss of ``similarity`` and of its transpose.

    The earlier ``contrastive_loss(similarity, dim=...)`` calls are removed:
    the single-argument ``contrastive_loss`` does not accept ``dim``, and
    their results were overwritten before being used.

    Args:
        similarity: Pairwise score matrix. NOTE(review): which axis indexes
            captions and which indexes images is not visible here -- confirm
            against the caller.

    Returns:
        The mean of the two directional losses.
    """
    caption_loss = contrastive_loss(similarity)
    # Transposing swaps the roles of rows and columns for the second term.
    image_loss = contrastive_loss(similarity.T)
    return (caption_loss + image_loss) / 2.0