Make sure padding, cls and another non-context tokens cannot appear in the answer.

This commit is contained in:
Morgan Funtowicz 2019-12-10 16:47:35 +01:00
parent 40a39ab650
commit 63e36007ee

View File

@ -188,14 +188,18 @@ class QuestionAnsweringPipeline(Pipeline):
start, end = start.cpu().numpy(), end.cpu().numpy() start, end = start.cpu().numpy(), end.cpu().numpy()
answers = [] answers = []
for i, (example, feature, start_, end_) in enumerate(zip(texts, features, start, end)): for (example, feature, start_, end_) in zip(texts, features, start, end):
start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
# Normalize logits and spans to retrieve the answer # Normalize logits and spans to retrieve the answer
start_ = np.exp(start_) / np.sum(np.exp(start_)) start_ = np.exp(start_) / np.sum(np.exp(start_))
end_ = np.exp(end_) / np.sum(np.exp(end_)) end_ = np.exp(end_) / np.sum(np.exp(end_))
starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
# Mask padding and question
start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
# Mask CLS
start_[0] = end_[0] = 0
starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
char_to_word = np.array(example.char_to_word_offset) char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text # Convert the answer (tokens) back to the original text