Fix test_model_parallelism for FalconModel (#24914)

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-07-19 13:18:16 +02:00 committed by GitHub
parent c035970212
commit 243b2ea3fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1040,7 +1040,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(