[Flava] Fix flava torch.distributed.nn.functional import all_gather issue (#23108)

* fix flava `torch.distributed.nn.functional import all_gather` issue

* more comments
This commit is contained in:
Younes Belkada 2023-05-02 15:35:57 +02:00 committed by GitHub
parent c6c6658499
commit 4baa34c18f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1693,8 +1693,10 @@ class FlavaGlobalContrastiveHead(nn.Module):
world_size = torch.distributed.get_world_size()
if self.global_backprop_contrastive:
image_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(image_embeddings)
text_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(text_embeddings)
# `torch.distributed.nn.functional.all_gather` does backprop on all active workers
# whereas `torch.distributed.all_gather` only backpropagates on the current worker.
image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
else:
image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]