Make ClipSeg compatible with model parallelism (#22844)

This commit is contained in:
Youssef Adarrab 2023-04-18 23:31:59 +00:00 committed by GitHub
parent 5bb4ec6233
commit 84a6570e7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1480,6 +1480,8 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
loss = None
if labels is not None:
# move labels to the correct device to enable PP
labels = labels.to(logits.device)
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels)