Moved labels to enable parallelism pipeline in Luke model (#22909)

This commit is contained in:
SUSHMANTH REDDY 2023-04-21 14:49:15 +05:30 committed by GitHub
parent 397720fb14
commit aab14120d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1370,6 +1370,8 @@ class LukeForMaskedLM(LukePreTrainedModel):
mlm_loss = None
logits = self.lm_head(outputs.last_hidden_state)
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
if loss is None:
loss = mlm_loss
@ -1505,6 +1507,8 @@ class LukeForEntityClassification(LukePreTrainedModel):
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if labels.ndim == 1:
loss = nn.functional.cross_entropy(logits, labels)
else:
@ -1623,6 +1627,8 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if labels.ndim == 1:
loss = nn.functional.cross_entropy(logits, labels)
else:
@ -1765,6 +1771,8 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
if labels.ndim == 2:
@ -1862,6 +1870,8 @@ class LukeForSequenceClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
@ -1975,6 +1985,8 @@ class LukeForTokenClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@ -2216,6 +2228,8 @@ class LukeForMultipleChoice(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(reshaped_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)