Fix QA example (#30580)

* Handle cases when CLS token is absent

* Use BOS token as a fallback
This commit is contained in:
Matt 2024-05-01 08:43:02 +01:00 committed by GitHub
parent 4b4da18f53
commit 1e05671d21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 6 deletions

View File

@ -434,7 +434,12 @@ def main():
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)

View File

@ -417,7 +417,12 @@ def main():
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
@ -534,7 +539,12 @@ def main():
for i, input_ids in enumerate(tokenized_examples["input_ids"]):
# Find the CLS token in the input ids.
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).

View File

@ -444,7 +444,12 @@ def main():
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
@ -563,7 +568,12 @@ def main():
for i, input_ids in enumerate(tokenized_examples["input_ids"]):
# Find the CLS token in the input ids.
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).

View File

@ -513,7 +513,12 @@ def main():
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)