mirror of https://github.com/huggingface/transformers.git
Make sure padding, CLS and other non-context tokens cannot appear in the answer.
parent 40a39ab650
commit 63e36007ee
@@ -188,14 +188,18 @@ class QuestionAnsweringPipeline(Pipeline):
             start, end = start.cpu().numpy(), end.cpu().numpy()

         answers = []
-        for i, (example, feature, start_, end_) in enumerate(zip(texts, features, start, end)):
-            start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
-
+        for (example, feature, start_, end_) in zip(texts, features, start, end):
             # Normalize logits and spans to retrieve the answer
             start_ = np.exp(start_) / np.sum(np.exp(start_))
             end_ = np.exp(end_) / np.sum(np.exp(end_))
+
+            # Mask padding and question
+            start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
+
+            # Mask CLS
+            start_[0] = end_[0] = 0
+
             starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
             char_to_word = np.array(example.char_to_word_offset)

             # Convert the answer (tokens) back to the original text

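Why the reordering matters: the old code multiplied the raw logits by the mask before the softmax, so a masked position was merely set to a logit of 0, and since exp(0) = 1 it still received probability mass after normalization and could surface among the top-k spans. Masking after the softmax drives those positions to exactly zero. Below is a minimal sketch of the new ordering, using toy logits and a hand-written p_mask (both illustrative, not taken from the pipeline):

import numpy as np

# Toy start logits for a 6-token feature: index 0 is [CLS], indices 1-2 are
# question tokens, 3-4 are context tokens, 5 is padding. (Illustrative values.)
start_ = np.array([4.0, 2.0, 1.0, 3.0, 2.5, 0.5])

# p_mask is 1 for tokens that may not appear in the answer (question, padding).
# [CLS] carries p_mask == 0, which is why the diff masks it in a separate step.
p_mask = np.array([0, 1, 1, 0, 0, 1])

# Normalize logits into probabilities first (softmax), as the new ordering does.
start_ = np.exp(start_) / np.sum(np.exp(start_))

# Mask padding and question: zero the probability wherever p_mask == 1.
start_ = start_ * np.abs(p_mask - 1)

# Mask CLS: force position 0 to zero so [CLS] can never be a span boundary.
start_[0] = 0

print(start_.round(3))  # only the context positions (3 and 4) stay non-zero

Running the sketch shows that only context positions keep non-zero probability, so the subsequent decode step can never select padding, question, or [CLS] tokens as answer-span boundaries.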