diff --git a/transformers/pipelines.py b/transformers/pipelines.py
index 1e2f035d9f8..eec49323215 100755
--- a/transformers/pipelines.py
+++ b/transformers/pipelines.py
@@ -188,14 +188,18 @@ class QuestionAnsweringPipeline(Pipeline):
                 start, end = start.cpu().numpy(), end.cpu().numpy()
 
         answers = []
-        for i, (example, feature, start_, end_) in enumerate(zip(texts, features, start, end)):
-            start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
-
+        for (example, feature, start_, end_) in zip(texts, features, start, end):
             # Normalize logits and spans to retrieve the answer
             start_ = np.exp(start_) / np.sum(np.exp(start_))
             end_ = np.exp(end_) / np.sum(np.exp(end_))
 
-            starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
+            # Mask padding and question
+            start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
+
+            # Mask CLS
+            start_[0] = end_[0] = 0
+
+            starts, ends, scores = self.decode(start_, end_, kwargs['topk'], kwargs['max_answer_len'])
            char_to_word = np.array(example.char_to_word_offset)
 
            # Convert the answer (tokens) back to the original text
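
The patch reorders the steps so the softmax is computed over the full logits first, and only then are question/padding positions (where `feature.p_mask == 1`) and the CLS token zeroed out, so they can never be picked as span boundaries by `decode`. Below is a minimal standalone sketch of that ordering under the same assumptions; the helper name `masked_span_probs` and the toy inputs are illustrative, not part of the pipeline itself.

```python
import numpy as np

def masked_span_probs(start_logits, end_logits, p_mask):
    """Sketch of the patched ordering: softmax first, then mask."""
    # Normalize logits over all positions (as the pipeline does before masking).
    start_probs = np.exp(start_logits) / np.sum(np.exp(start_logits))
    end_probs = np.exp(end_logits) / np.sum(np.exp(end_logits))

    # Mask padding and question: p_mask is 1 where a token cannot be part of
    # the answer, so abs(p_mask - 1) keeps only candidate context tokens.
    keep = np.abs(np.array(p_mask) - 1)
    start_probs, end_probs = start_probs * keep, end_probs * keep

    # Mask CLS so position 0 is never selected as a span start or end.
    start_probs[0] = end_probs[0] = 0
    return start_probs, end_probs

# Toy example: 6 tokens, index 0 is CLS, indices 1-2 are question tokens.
start_logits = np.array([5.0, 0.1, 0.2, 2.0, 4.0, 0.3])
end_logits = np.array([5.0, 0.1, 0.2, 0.5, 1.0, 4.0])
p_mask = [0, 1, 1, 0, 0, 0]
print(masked_span_probs(start_logits, end_logits, p_mask))
```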