Merge pull request #2276 from ShnitzelKiller/scatterfix

fix error due to wrong argument name to Tensor.scatter()
This commit is contained in:
Thomas Wolf 2019-12-23 12:19:48 +01:00 committed by GitHub
commit e4e2a666c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -958,7 +958,9 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, source=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits