Filter out entity for NER task.

This commit is contained in:
Morgan Funtowicz 2019-12-20 09:30:37 +01:00
parent e4baa68ddb
commit 9d0d1cd339

View File

@ -450,11 +450,12 @@ class NerPipeline(Pipeline):
score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
label_idx = score.argmax()
answer += [{
'word': words[idx],
'score': score[label_idx].item(),
'entity': self.model.config.id2label[label_idx]
}]
if label_idx > 0:
answer += [{
'word': words[idx],
'score': score[label_idx].item(),
'entity': self.model.config.id2label[label_idx]
}]
# Update token start
token_start += len(list(word))