mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Move labels to the same device as logits for LlamaForSequenceClassification and Blip2 (#22596)
* (feat): Move labels to the same device as logits * Trigger CI * Trigger CI * Trigger CI * (feat): Making changes for Blip2
This commit is contained in:
parent
d59034ff6f
commit
1de8ce9ee1
@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
loss = None
|
||||
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
logits = logits[:, -labels.size(1) :, :]
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
||||
loss = None
|
||||
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
logits = logits[:, -labels.size(1) :, :]
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
|
@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
|
Loading…
Reference in New Issue
Block a user