mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
ensure banned_mask and indices in same device (#23901)
* ensure banned_mask and indices in same device * ensure banned_mask and indices in same device switch the order in which indices and banned_mask are created and create banned_mask on the proper device
This commit is contained in:
parent
d68d6665f9
commit
d99f11e898
@ -649,8 +649,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
||||
|
||||
else:
|
||||
if banned_mask_list:
|
||||
banned_mask = torch.LongTensor(banned_mask_list)
|
||||
indices = torch.ones(len(banned_mask))
|
||||
indices = torch.ones(len(banned_mask_list))
|
||||
banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
|
||||
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
||||
# [ 0 1 1 ]
|
||||
# [ 0 0 0 ]
|
||||
|
Loading…
Reference in New Issue
Block a user