mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix test_model_parallelism
for FalconModel
(#24914)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
c035970212
commit
243b2ea3fd
@ -1040,7 +1040,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
|
||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
|
Loading…
Reference in New Issue
Block a user