mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Moved labels to enable parallelism pipeline in Luke model (#22909)
This commit is contained in:
parent
397720fb14
commit
aab14120d4
@ -1370,6 +1370,8 @@ class LukeForMaskedLM(LukePreTrainedModel):
|
||||
mlm_loss = None
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
if loss is None:
|
||||
loss = mlm_loss
|
||||
@ -1505,6 +1507,8 @@ class LukeForEntityClassification(LukePreTrainedModel):
|
||||
if labels is not None:
|
||||
# When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. The binary
|
||||
# cross entropy is used otherwise.
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if labels.ndim == 1:
|
||||
loss = nn.functional.cross_entropy(logits, labels)
|
||||
else:
|
||||
@ -1623,6 +1627,8 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
|
||||
if labels is not None:
|
||||
# When the number of dimensions of `labels` is 1, cross entropy is used as the loss function. The binary
|
||||
# cross entropy is used otherwise.
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if labels.ndim == 1:
|
||||
loss = nn.functional.cross_entropy(logits, labels)
|
||||
else:
|
||||
@ -1765,6 +1771,8 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# When the number of dimensions of `labels` is 2, cross entropy is used as the loss function. The binary
|
||||
# cross entropy is used otherwise.
|
||||
if labels.ndim == 2:
|
||||
@ -1862,6 +1870,8 @@ class LukeForSequenceClassification(LukePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
@ -1975,6 +1985,8 @@ class LukeForTokenClassification(LukePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
@ -2216,6 +2228,8 @@ class LukeForMultipleChoice(LukePreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(reshaped_logits.device)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user