p_mask in SQuAD pre-processing (#4049)

* Better p_mask building

* Adressing @mfuntowicz comments
This commit is contained in:
Lysandre Debut 2020-05-14 17:07:52 -04:00 committed by GitHub
parent 84894974bd
commit 7defc6670f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -195,18 +195,22 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# Original TF implem also keep the classification token (set to 0) (not sure why...)
p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1)
# Original TF implem also keep the classification token (set to 0)
p_mask = np.ones_like(span["token_type_ids"])
if tokenizer.padding_side == "right":
# Limit positive values to one
p_mask = 1 - p_mask
p_mask[len(truncated_query) + sequence_added_tokens :] = 0
else:
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
special_token_indices = np.asarray(
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
).nonzero()
# Set the CLS index to '0'
p_mask[pad_token_indices] = 1
p_mask[special_token_indices] = 1
# Set the cls index to 0: the CLS index can be used for impossible answers
p_mask[cls_index] = 0
span_is_impossible = example.is_impossible