Mirror of https://github.com/huggingface/transformers.git
fix pipeline NER

parent e37ca8e11a
commit a241011057
@@ -463,7 +463,7 @@ class NerPipeline(Pipeline):
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
-                 binary_output: bool = False):
+                 binary_output: bool = False, ignore_labels=['O']):
         super().__init__(model=model,
                          tokenizer=tokenizer,
                          modelcard=modelcard,
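For context, a minimal usage sketch of the new ignore_labels argument. The example sentence is illustrative, and pipeline("ner") assumes a default NER checkpoint can be downloaded:

    from transformers import pipeline

    ner = pipeline("ner")  # default ignore_labels=['O'] drops outside-of-entity tokens
    print(ner("My name is Wolfgang and I live in Berlin"))

    # The attribute added in this commit can also be changed after construction:
    ner.ignore_labels = []  # return every token, including plain 'O' predictions
    print(ner("My name is Wolfgang and I live in Berlin"))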
@@ -473,17 +473,12 @@ class NerPipeline(Pipeline):
                          binary_output=binary_output)
 
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
+        self.ignore_labels = ignore_labels
 
     def __call__(self, *texts, **kwargs):
         inputs, answers = self._args_parser(*texts, **kwargs), []
         for sentence in inputs:
 
-            # Ugly token to word idx mapping (for now)
-            token_to_word, words = [], self._basic_tokenizer.tokenize(sentence)
-            for i, w in enumerate(words):
-                tokens = self.tokenizer.tokenize(w)
-                token_to_word += [i] * len(tokens)
-
             # Manage correct placement of the tensors
             with self.device_placement():
 
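The block removed above existed only to map wordpieces back to whole words with BasicTokenizer. After this commit the pipeline reports one entry per wordpiece and recovers each token's text by decoding its input id directly, as the next hunk shows. A rough sketch of that decoding step (checkpoint name is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint
    input_ids = tokenizer.encode("My name is Wolfgang")
    # One decoded string per wordpiece, matching the 'word' field in the hunk below.
    print([tokenizer.decode([i]) for i in input_ids])
    # e.g. ['[CLS]', 'My', 'name', 'is', 'Wolfgang', '[SEP]']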
@@ -500,26 +495,22 @@ class NerPipeline(Pipeline):
                 with torch.no_grad():
                     entities = self.model(**tokens)[0][0].cpu().numpy()
 
-            # Normalize scores
-            answer, token_start = [], 1
-            for idx, word in groupby(token_to_word):
-
-                # Sum log prob over token, then normalize across labels
-                score = np.exp(entities[token_start]) / np.exp(entities[token_start]).sum(-1, keepdims=True)
-                label_idx = score.argmax()
-
-                if label_idx > 0:
+            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
+            labels_idx = score.argmax(axis=-1)
+
+            answer = []
+            for idx, label_idx in enumerate(labels_idx):
+                if self.model.config.id2label[label_idx] not in self.ignore_labels:
                     answer += [{
-                        'word': words[idx],
-                        'score': score[label_idx].item(),
+                        'word': self.tokenizer.decode(tokens['input_ids'][0][idx].cpu().tolist()),
+                        'score': score[idx][label_idx].item(),
                         'entity': self.model.config.id2label[label_idx]
                     }]
 
-            # Update token start
-            token_start += len(list(word))
-
             # Append
             answers += [answer]
         if len(answers) == 1:
             return answers[0]
         return answers
+
+
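The rewritten scoring is a softmax over the label dimension followed by a per-token argmax, applied to the whole sequence at once instead of token by token. A standalone numpy sketch of the same computation, with random logits standing in for the model output and an illustrative id2label mapping:

    import numpy as np

    np.random.seed(0)
    entities = np.random.randn(6, 5)  # (seq_len, num_labels) logits for one sentence

    # Row-wise softmax, then the best label id per token, as in the hunk above.
    score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
    labels_idx = score.argmax(axis=-1)

    id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-LOC', 4: 'I-LOC'}  # illustrative
    answer = [{'score': score[idx][label_idx].item(), 'entity': id2label[label_idx]}
              for idx, label_idx in enumerate(labels_idx)
              if id2label[label_idx] not in ['O']]
    print(answer)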