Mirror of https://github.com/huggingface/transformers.git
[Feature] Support is_split_into_words in the TokenClassificationPipeline. (#38818)
* some fixes
* the pipeline can now take a list of tokens as input via the is_split_into_words argument
* batches of pre-tokenized input are handled as well
* solving test problems
* modify tests
* aligning start and end correctly
* adding tests
* some formatting
* resolve conflicts
* removing unimportant lines
* generalize to other languages
commit 9eac19eb59
parent 2ce02b98bf
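In practice, the feature lets the token-classification pipeline accept pre-tokenized input. A minimal usage sketch, based on the test added further down (the dslim/bert-base-NER checkpoint and aggregation_strategy="simple" come from that test):

    from transformers import pipeline

    token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")

    # Single pre-tokenized sentence: a list of words instead of a raw string.
    words = ["Hello", "Sarah", "lives", "in", "New", "York"]
    entities = token_classifier(words, is_split_into_words=True)

    # Batched pre-tokenized input: a list of word lists.
    batch = token_classifier([words, ["My", "name", "is", "Wolfgang"]], is_split_into_words=True)

    # The returned "start"/"end" offsets refer to the sentence obtained by joining the words with the
    # delimiter (a single space unless delimiter=... is passed), e.g. "Hello Sarah lives in New York".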
@@ -30,6 +30,9 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
     """

     def __call__(self, inputs: Union[str, list[str]], **kwargs):
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        delimiter = kwargs.get("delimiter", None)
+
         if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
             inputs = list(inputs)
             batch_size = len(inputs)
@@ -37,7 +40,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
             inputs = [inputs]
             batch_size = 1
         elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
-            return inputs, None
+            return inputs, is_split_into_words, None, delimiter
         else:
             raise ValueError("At least one input is required.")
@@ -47,7 +50,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
                 offset_mapping = [offset_mapping]
             if len(offset_mapping) != batch_size:
                 raise ValueError("offset_mapping should have the same batch size as the input")
-        return inputs, offset_mapping
+        return inputs, is_split_into_words, offset_mapping, delimiter


 class AggregationStrategy(ExplicitEnum):
@@ -135,6 +138,7 @@ class TokenClassificationPipeline(ChunkPipeline):

     def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
         super().__init__(*args, **kwargs)
+
         self.check_model_type(
             TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
             if self.framework == "tf"
@@ -151,9 +155,16 @@ class TokenClassificationPipeline(ChunkPipeline):
         ignore_subwords: Optional[bool] = None,
         aggregation_strategy: Optional[AggregationStrategy] = None,
         offset_mapping: Optional[list[tuple[int, int]]] = None,
+        is_split_into_words: Optional[bool] = False,
         stride: Optional[int] = None,
+        delimiter: Optional[str] = None,
     ):
         preprocess_params = {}
+        preprocess_params["is_split_into_words"] = is_split_into_words
+
+        if is_split_into_words:
+            preprocess_params["delimiter"] = " " if delimiter is None else delimiter
+
         if offset_mapping is not None:
             preprocess_params["offset_mapping"] = offset_mapping
@@ -230,8 +241,9 @@ class TokenClassificationPipeline(ChunkPipeline):
         Classify each token of the text(s) given as inputs.

         Args:
-            inputs (`str` or `list[str]`):
-                One or several texts (or one list of texts) for token classification.
+            inputs (`str` or `List[str]`):
+                One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
+                `is_split_into_words=True`.

         Return:
             A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -251,7 +263,11 @@ class TokenClassificationPipeline(ChunkPipeline):
              exists if the offsets are available within the tokenizer
         """

-        _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
+        _inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
+        kwargs["is_split_into_words"] = is_split_into_words
+        kwargs["delimiter"] = delimiter
+        if is_split_into_words and not all(isinstance(input, list) for input in inputs):
+            return super().__call__([inputs], **kwargs)
         if offset_mapping:
             kwargs["offset_mapping"] = offset_mapping
@@ -260,14 +276,43 @@ class TokenClassificationPipeline(ChunkPipeline):
     def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
         tokenizer_params = preprocess_params.pop("tokenizer_params", {})
         truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
+
+        word_to_chars_map = None
+        is_split_into_words = preprocess_params["is_split_into_words"]
+        if is_split_into_words:
+            delimiter = preprocess_params["delimiter"]
+            if not isinstance(sentence, list):
+                raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
+            words = sentence
+            sentence = delimiter.join(words)  # Recreate the sentence string for later display and slicing
+            # This map allows converting word indices back to char indices
+            word_to_chars_map = []
+            delimiter_len = len(delimiter)
+            char_offset = 0
+            for word in words:
+                word_to_chars_map.append((char_offset, char_offset + len(word)))
+                char_offset += len(word) + delimiter_len
+
+            # We use `words` as the actual input for the tokenizer
+            text_to_tokenize = words
+            tokenizer_params["is_split_into_words"] = True
+        else:
+            if not isinstance(sentence, str):
+                raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
+            text_to_tokenize = sentence
+
         inputs = self.tokenizer(
-            sentence,
+            text_to_tokenize,
             return_tensors=self.framework,
             truncation=truncation,
             return_special_tokens_mask=True,
             return_offsets_mapping=self.tokenizer.is_fast,
             **tokenizer_params,
         )
+
+        if is_split_into_words and not self.tokenizer.is_fast:
+            raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
+
         inputs.pop("overflow_to_sample_mapping", None)
         num_chunks = len(inputs["input_ids"])
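For clarity, here is a standalone sketch of the word-to-character map that preprocess() builds above; the helper name build_word_to_chars_map is ours, not part of the commit:

    def build_word_to_chars_map(words, delimiter=" "):
        # Mirrors the logic added in preprocess(): each word is mapped to its
        # (start, end) character span in the delimiter-joined sentence.
        sentence = delimiter.join(words)
        word_to_chars_map = []
        char_offset = 0
        for word in words:
            word_to_chars_map.append((char_offset, char_offset + len(word)))
            char_offset += len(word) + len(delimiter)
        return sentence, word_to_chars_map

    sentence, spans = build_word_to_chars_map(["Hello", "Sarah", "lives", "in", "New", "York"])
    # sentence == "Hello Sarah lives in New York"
    # spans    == [(0, 5), (6, 11), (12, 17), (18, 20), (21, 24), (25, 29)]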
@@ -278,8 +323,12 @@ class TokenClassificationPipeline(ChunkPipeline):
                 model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
             if offset_mapping is not None:
                 model_inputs["offset_mapping"] = offset_mapping
+
             model_inputs["sentence"] = sentence if i == 0 else None
             model_inputs["is_last"] = i == num_chunks - 1
+            if word_to_chars_map is not None:
+                model_inputs["word_ids"] = inputs.word_ids(i)
+                model_inputs["word_to_chars_map"] = word_to_chars_map

             yield model_inputs
@@ -289,6 +338,9 @@ class TokenClassificationPipeline(ChunkPipeline):
         offset_mapping = model_inputs.pop("offset_mapping", None)
         sentence = model_inputs.pop("sentence")
         is_last = model_inputs.pop("is_last")
+        word_ids = model_inputs.pop("word_ids", None)
+        word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
+
         if self.framework == "tf":
             logits = self.model(**model_inputs)[0]
         else:
@@ -301,6 +353,8 @@ class TokenClassificationPipeline(ChunkPipeline):
             "offset_mapping": offset_mapping,
             "sentence": sentence,
             "is_last": is_last,
+            "word_ids": word_ids,
+            "word_to_chars_map": word_to_chars_map,
             **model_inputs,
         }
@@ -308,6 +362,10 @@ class TokenClassificationPipeline(ChunkPipeline):
         if ignore_labels is None:
             ignore_labels = ["O"]
         all_entities = []
+
+        # Get map from the first output, it's the same for all chunks
+        word_to_chars_map = all_outputs[0].get("word_to_chars_map")
+
         for model_outputs in all_outputs:
             if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
                 logits = model_outputs["logits"][0].to(torch.float32).numpy()
@@ -320,6 +378,7 @@ class TokenClassificationPipeline(ChunkPipeline):
                 model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
             )
             special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
+            word_ids = model_outputs.get("word_ids")

             maxes = np.max(logits, axis=-1, keepdims=True)
             shifted_exp = np.exp(logits - maxes)
@@ -330,7 +389,14 @@ class TokenClassificationPipeline(ChunkPipeline):
             offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None

             pre_entities = self.gather_pre_entities(
-                sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
+                sentence,
+                input_ids,
+                scores,
+                offset_mapping,
+                special_tokens_mask,
+                aggregation_strategy,
+                word_ids=word_ids,
+                word_to_chars_map=word_to_chars_map,
             )
             grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
             # Filter anything that is in self.ignore_labels
@@ -374,6 +440,8 @@ class TokenClassificationPipeline(ChunkPipeline):
         offset_mapping: Optional[list[tuple[int, int]]],
         special_tokens_mask: np.ndarray,
         aggregation_strategy: AggregationStrategy,
+        word_ids: Optional[list[Optional[int]]] = None,
+        word_to_chars_map: Optional[list[tuple[int, int]]] = None,
     ) -> list[dict]:
         """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
         pre_entities = []
@@ -385,6 +453,15 @@ class TokenClassificationPipeline(ChunkPipeline):
             word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
             if offset_mapping is not None:
                 start_ind, end_ind = offset_mapping[idx]
+
+                # If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
+                if word_ids is not None and word_to_chars_map is not None:
+                    word_index = word_ids[idx]
+                    if word_index is not None:
+                        start_char, _ = word_to_chars_map[word_index]
+                        start_ind += start_char
+                        end_ind += start_char
+
                 if not isinstance(start_ind, int):
                     if self.framework == "pt":
                         start_ind = start_ind.item()
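A worked example of the rescaling above, using the "Hello Sarah lives in New York" input from the new test. The concrete token offsets and word index are what a fast tokenizer's offset_mapping and word_ids() would typically report for the token "New"; treat them as illustrative values:

    word_to_chars_map = [(0, 5), (6, 11), (12, 17), (18, 20), (21, 24), (25, 29)]
    word_index = 4                   # word_ids(idx) for the token "New"
    token_start, token_end = 0, 3    # offsets relative to the word itself
    start_char, _ = word_to_chars_map[word_index]
    start_ind = token_start + start_char   # 21
    end_ind = token_end + start_char       # 24 -> position of "New" in the joined sentence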
@@ -308,6 +308,54 @@ class TokenClassificationPipelineTests(unittest.TestCase):
             ],
         )

+    @require_torch
+    @slow
+    def test_is_split_into_words(self):
+        """
+        Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
+        and validates that the character offsets are correct.
+        """
+        token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
+
+        # Input is a list of words
+        words = ["Hello", "Sarah", "lives", "in", "New", "York"]
+
+        # The reconstructed sentence will be "Hello Sarah lives in New York"
+        # - "Sarah": starts at index 6, ends at 11
+        # - "New York": starts at index 21, ends at 29
+
+        output = token_classifier(words, is_split_into_words=True)
+
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
+                {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
+            ],
+        )
+
+        # Also test batching with pre-tokenized inputs
+        words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
+        batch_output = token_classifier([words, words2], is_split_into_words=True)
+
+        # Expected for the second sentence ("My name is Wolfgang and I live in Berlin")
+        # - "Wolfgang": starts at 11, ends at 19
+        # - "Berlin": starts at 34, ends at 40
+
+        self.assertEqual(
+            nested_simplify(batch_output),
+            [
+                [
+                    {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
+                    {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
+                ],
+                [
+                    {"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 11, "end": 19},
+                    {"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 34, "end": 40},
+                ],
+            ],
+        )
+
     @require_torch
     def test_chunking_fast(self):
         # Note: We cannot run the test on "conflicts" on the chunking.
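The character offsets asserted above can be sanity-checked with plain string arithmetic; this snippet is a verification aid, not part of the commit:

    sentence = " ".join(["Hello", "Sarah", "lives", "in", "New", "York"])
    assert (sentence.index("Sarah"), sentence.index("Sarah") + len("Sarah")) == (6, 11)
    assert (sentence.index("New York"), sentence.index("New York") + len("New York")) == (21, 29)

    sentence2 = " ".join(["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"])
    assert (sentence2.index("Wolfgang"), sentence2.index("Wolfgang") + len("Wolfgang")) == (11, 19)
    assert (sentence2.index("Berlin"), sentence2.index("Berlin") + len("Berlin")) == (34, 40)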
@@ -953,19 +1001,24 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
     def test_simple(self):
         string = "This is a simple input"

-        inputs, offset_mapping = self.args_parser(string)
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
         self.assertEqual(inputs, [string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, None)

-        inputs, offset_mapping = self.args_parser([string, string])
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
         self.assertEqual(inputs, [string, string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, None)

-        inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
+            string, offset_mapping=[(0, 1), (1, 2)]
+        )
         self.assertEqual(inputs, [string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])

-        inputs, offset_mapping = self.args_parser(
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
             [string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
         )
         self.assertEqual(inputs, [string, string])