Update question_answering.py (#32208)

This commit is contained in:
Austin 2024-07-25 07:20:27 -05:00 committed by GitHub
parent f53a5dec7b
commit 1ecedf1d9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -118,7 +118,7 @@ def select_starts_ends(
max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
"""
# Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
undesired_tokens = np.abs(p_mask.numpy() - 1)
undesired_tokens = np.abs(np.array(p_mask) - 1)
if attention_mask is not None:
undesired_tokens = undesired_tokens & attention_mask