ensure banned_mask and indices in same device (#23901)

* ensure banned_mask and indices in same device

* ensure banned_mask and indices in same device

switch the order in which indices and banned_mask are created and create banned_mask on the proper device
This commit is contained in:
Xinyu Yang 2023-05-31 21:47:46 +08:00 committed by GitHub
parent d68d6665f9
commit d99f11e898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -649,8 +649,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
else:
if banned_mask_list:
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
indices = torch.ones(len(banned_mask_list))
banned_mask = torch.LongTensor(banned_mask_list, device=indices.device)
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]