mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Make ClipSeg compatible with model parallelism (#22844)
This commit is contained in:
parent
5bb4ec6233
commit
84a6570e7b
@ -1480,6 +1480,8 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
# move labels to the correct device to enable PP
|
||||||
|
labels = labels.to(logits.device)
|
||||||
loss_fn = nn.BCEWithLogitsLoss()
|
loss_fn = nn.BCEWithLogitsLoss()
|
||||||
loss = loss_fn(logits, labels)
|
loss = loss_fn(logits, labels)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user