Move labels to the same device as logits for LlamaForSequenceClassification and Blip2 (#22596)

* (feat): Move labels to the same device as logits

* Trigger CI

* Trigger CI

* Trigger CI

* (feat): Making changes for Blip2
This commit is contained in:
Shikhar Chauhan 2023-04-07 14:23:55 +02:00 committed by GitHub
parent d59034ff6f
commit 1de8ce9ee1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 0 deletions

View File

@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()

View File

@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"