mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Manage impossible examples SQuAD v2
This commit is contained in:
parent
983c484fa2
commit
073219b43f
@ -242,6 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
|
||||
token_to_orig_map=span["token_to_orig_map"],
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible
|
||||
)
|
||||
)
|
||||
return features
|
||||
@ -332,6 +333,7 @@ def squad_convert_examples_to_features(
|
||||
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
||||
|
||||
if not is_training:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
@ -349,6 +351,7 @@ def squad_convert_examples_to_features(
|
||||
all_end_positions,
|
||||
all_cls_index,
|
||||
all_p_mask,
|
||||
all_is_impossible
|
||||
)
|
||||
|
||||
return features, dataset
|
||||
@ -369,6 +372,7 @@ def squad_convert_examples_to_features(
|
||||
"end_position": ex.end_position,
|
||||
"cls_index": ex.cls_index,
|
||||
"p_mask": ex.p_mask,
|
||||
"is_impossible": ex.is_impossible
|
||||
},
|
||||
)
|
||||
|
||||
@ -376,7 +380,7 @@ def squad_convert_examples_to_features(
|
||||
gen,
|
||||
(
|
||||
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
|
||||
{"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
|
||||
{"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32, "is_impossible": tf.int32},
|
||||
),
|
||||
(
|
||||
{
|
||||
@ -389,6 +393,7 @@ def squad_convert_examples_to_features(
|
||||
"end_position": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([])
|
||||
},
|
||||
),
|
||||
)
|
||||
@ -658,6 +663,7 @@ class SquadFeatures(object):
|
||||
token_to_orig_map,
|
||||
start_position,
|
||||
end_position,
|
||||
is_impossible
|
||||
):
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
@ -674,7 +680,7 @@ class SquadFeatures(object):
|
||||
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
|
||||
self.is_impossible = is_impossible
|
||||
|
||||
class SquadResult(object):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user