mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Added parallel device usage for GPT-J (#22713)
This commit is contained in:
parent
b76e6ebd44
commit
17503b00ea
1
src/transformers/models/gptj/modeling_gptj.py
Executable file → Normal file
1
src/transformers/models/gptj/modeling_gptj.py
Executable file → Normal file
@ -1012,6 +1012,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(pooled_logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
|
Loading…
Reference in New Issue
Block a user