mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
p_mask in SQuAD pre-processing (#4049)
* Better p_mask building * Adressing @mfuntowicz comments
This commit is contained in:
parent
84894974bd
commit
7defc6670f
@ -195,18 +195,22 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
|
||||
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
|
||||
|
||||
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
||||
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
||||
p_mask = np.array(span["token_type_ids"])
|
||||
|
||||
p_mask = np.minimum(p_mask, 1)
|
||||
|
||||
# Original TF implem also keep the classification token (set to 0)
|
||||
p_mask = np.ones_like(span["token_type_ids"])
|
||||
if tokenizer.padding_side == "right":
|
||||
# Limit positive values to one
|
||||
p_mask = 1 - p_mask
|
||||
p_mask[len(truncated_query) + sequence_added_tokens :] = 0
|
||||
else:
|
||||
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
|
||||
|
||||
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
|
||||
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
|
||||
special_token_indices = np.asarray(
|
||||
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
|
||||
).nonzero()
|
||||
|
||||
# Set the CLS index to '0'
|
||||
p_mask[pad_token_indices] = 1
|
||||
p_mask[special_token_indices] = 1
|
||||
|
||||
# Set the cls index to 0: the CLS index can be used for impossible answers
|
||||
p_mask[cls_index] = 0
|
||||
|
||||
span_is_impossible = example.is_impossible
|
||||
|
Loading…
Reference in New Issue
Block a user