Added parallel device usage for GPT-J (#22713)

This commit is contained in:
jprivera44 2023-04-12 04:31:27 -07:00 committed by GitHub
parent b76e6ebd44
commit 17503b00ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

1
src/transformers/models/gptj/modeling_gptj.py Executable file → Normal file
View File

@ -1012,6 +1012,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(pooled_logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"