mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flava] Fix flava `torch.distributed.nn.functional import all_gather` issue (#23108)
* fix flava `torch.distributed.nn.functional import all_gather` issue * more comments
This commit is contained in:
parent
c6c6658499
commit
4baa34c18f
@ -1693,8 +1693,10 @@ class FlavaGlobalContrastiveHead(nn.Module):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
if self.global_backprop_contrastive:
|
||||
image_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(image_embeddings)
|
||||
text_embeddings_all = torch.distributed.nn.functional.all_gather_with_backprop(text_embeddings)
|
||||
# `torch.distributed.nn.functional.all_gather` does backprop on all active workers
|
||||
# whereas `torch.distributed.all_gather` does only backpropagates on the current worker.
|
||||
image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
|
||||
text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
|
||||
else:
|
||||
image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
|
||||
text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
|
||||
|
Loading…
Reference in New Issue
Block a user