[Feature] Support is_split_into_words in the TokenClassificationPipeline. (#38818)

* some fixes

* now the pipeline can take a list of tokens as input together with the is_split_into_words argument

* now the pipeline can also handle batches of tokenized input

* solving test problems

* modify tests

* aligning start and end correctly

* adding tests

* some formatting

* resolve conflicts

* removing unimportant lines

* generalize to other languages
Yusuf Shihata 2025-06-23 18:31:32 +03:00 committed by GitHub
parent 2ce02b98bf
commit 9eac19eb59
2 changed files with 141 additions and 11 deletions
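
A minimal usage sketch of the new argument, for readers of this commit. It reuses the `dslim/bert-base-NER` checkpoint and settings from the test added below; the exact entities and scores are illustrative, not guaranteed by this sketch.

from transformers import pipeline

# NER pipeline, same checkpoint and aggregation strategy as in the new test below.
token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")

# Pre-tokenized input: a list of words instead of a raw string.
words = ["Hello", "Sarah", "lives", "in", "New", "York"]
entities = token_classifier(words, is_split_into_words=True)
# Offsets in the result refer to the sentence rebuilt with the delimiter (a space by default),
# i.e. "Hello Sarah lives in New York".

# Batches of pre-tokenized sentences work as well.
batch = token_classifier([words, ["My", "name", "is", "Wolfgang"]], is_split_into_words=True)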

View File

@@ -30,6 +30,9 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
"""
def __call__(self, inputs: Union[str, list[str]], **kwargs):
is_split_into_words = kwargs.get("is_split_into_words", False)
delimiter = kwargs.get("delimiter", None)
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
inputs = list(inputs)
batch_size = len(inputs)
@@ -37,7 +40,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
inputs = [inputs]
batch_size = 1
elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
return inputs, None
return inputs, is_split_into_words, None, delimiter
else:
raise ValueError("At least one input is required.")
@@ -47,7 +50,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
offset_mapping = [offset_mapping]
if len(offset_mapping) != batch_size:
raise ValueError("offset_mapping should have the same batch size as the input")
return inputs, offset_mapping
return inputs, is_split_into_words, offset_mapping, delimiter
class AggregationStrategy(ExplicitEnum):
@@ -135,6 +138,7 @@ class TokenClassificationPipeline(ChunkPipeline):
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
@@ -151,9 +155,16 @@ class TokenClassificationPipeline(ChunkPipeline):
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
offset_mapping: Optional[list[tuple[int, int]]] = None,
is_split_into_words: Optional[bool] = False,
stride: Optional[int] = None,
delimiter: Optional[str] = None,
):
preprocess_params = {}
preprocess_params["is_split_into_words"] = is_split_into_words
if is_split_into_words:
preprocess_params["delimiter"] = " " if delimiter is None else delimiter
if offset_mapping is not None:
preprocess_params["offset_mapping"] = offset_mapping
@@ -230,8 +241,9 @@ class TokenClassificationPipeline(ChunkPipeline):
Classify each token of the text(s) given as inputs.
Args:
inputs (`str` or `list[str]`):
One or several texts (or one list of texts) for token classification.
inputs (`str` or `List[str]`):
One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
`is_split_into_words=True`.
Return:
A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -251,7 +263,11 @@ class TokenClassificationPipeline(ChunkPipeline):
exists if the offsets are available within the tokenizer
"""
_inputs, offset_mapping = self._args_parser(inputs, **kwargs)
_inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
kwargs["is_split_into_words"] = is_split_into_words
kwargs["delimiter"] = delimiter
if is_split_into_words and not all(isinstance(input, list) for input in inputs):
return super().__call__([inputs], **kwargs)
if offset_mapping:
kwargs["offset_mapping"] = offset_mapping
@@ -260,14 +276,43 @@ class TokenClassificationPipeline(ChunkPipeline):
def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
tokenizer_params = preprocess_params.pop("tokenizer_params", {})
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
word_to_chars_map = None
is_split_into_words = preprocess_params["is_split_into_words"]
if is_split_into_words:
delimiter = preprocess_params["delimiter"]
if not isinstance(sentence, list):
raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
words = sentence
sentence = delimiter.join(words) # Recreate the sentence string for later display and slicing
# This map allows converting word indices back to char indices
word_to_chars_map = []
delimiter_len = len(delimiter)
char_offset = 0
for word in words:
word_to_chars_map.append((char_offset, char_offset + len(word)))
char_offset += len(word) + delimiter_len
# We use `words` as the actual input for the tokenizer
text_to_tokenize = words
tokenizer_params["is_split_into_words"] = True
else:
if not isinstance(sentence, str):
raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
text_to_tokenize = sentence
inputs = self.tokenizer(
sentence,
text_to_tokenize,
return_tensors=self.framework,
truncation=truncation,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
**tokenizer_params,
)
if is_split_into_words and not self.tokenizer.is_fast:
raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
inputs.pop("overflow_to_sample_mapping", None)
num_chunks = len(inputs["input_ids"])
@@ -278,8 +323,12 @@ class TokenClassificationPipeline(ChunkPipeline):
model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
if offset_mapping is not None:
model_inputs["offset_mapping"] = offset_mapping
model_inputs["sentence"] = sentence if i == 0 else None
model_inputs["is_last"] = i == num_chunks - 1
if word_to_chars_map is not None:
model_inputs["word_ids"] = inputs.word_ids(i)
model_inputs["word_to_chars_map"] = word_to_chars_map
yield model_inputs
@@ -289,6 +338,9 @@ class TokenClassificationPipeline(ChunkPipeline):
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
is_last = model_inputs.pop("is_last")
word_ids = model_inputs.pop("word_ids", None)
word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
if self.framework == "tf":
logits = self.model(**model_inputs)[0]
else:
@@ -301,6 +353,8 @@ class TokenClassificationPipeline(ChunkPipeline):
"offset_mapping": offset_mapping,
"sentence": sentence,
"is_last": is_last,
"word_ids": word_ids,
"word_to_chars_map": word_to_chars_map,
**model_inputs,
}
@@ -308,6 +362,10 @@ class TokenClassificationPipeline(ChunkPipeline):
if ignore_labels is None:
ignore_labels = ["O"]
all_entities = []
# Get the map from the first output; it is the same for all chunks
word_to_chars_map = all_outputs[0].get("word_to_chars_map")
for model_outputs in all_outputs:
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
logits = model_outputs["logits"][0].to(torch.float32).numpy()
@@ -320,6 +378,7 @@ class TokenClassificationPipeline(ChunkPipeline):
model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
)
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
word_ids = model_outputs.get("word_ids")
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
@@ -330,7 +389,14 @@ class TokenClassificationPipeline(ChunkPipeline):
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
sentence,
input_ids,
scores,
offset_mapping,
special_tokens_mask,
aggregation_strategy,
word_ids=word_ids,
word_to_chars_map=word_to_chars_map,
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
@@ -374,6 +440,8 @@ class TokenClassificationPipeline(ChunkPipeline):
offset_mapping: Optional[list[tuple[int, int]]],
special_tokens_mask: np.ndarray,
aggregation_strategy: AggregationStrategy,
word_ids: Optional[list[Optional[int]]] = None,
word_to_chars_map: Optional[list[tuple[int, int]]] = None,
) -> list[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
@@ -385,6 +453,15 @@ class TokenClassificationPipeline(ChunkPipeline):
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
# If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
if word_ids is not None and word_to_chars_map is not None:
word_index = word_ids[idx]
if word_index is not None:
start_char, _ = word_to_chars_map[word_index]
start_ind += start_char
end_ind += start_char
if not isinstance(start_ind, int):
if self.framework == "pt":
start_ind = start_ind.item()
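
As a standalone illustration (not pipeline code) of the offset handling added above: the word-to-character map built in `preprocess` lets the per-word offsets that a fast tokenizer returns for pre-tokenized input be shifted back to absolute positions in the reconstructed sentence, as done in `gather_pre_entities`.

# Rebuild the sentence from pre-tokenized words, as `preprocess` does.
words = ["New", "York"]
delimiter = " "
sentence = delimiter.join(words)  # "New York"

# (start, end) character span of each word inside the reconstructed sentence.
word_to_chars_map = []
char_offset = 0
for word in words:
    word_to_chars_map.append((char_offset, char_offset + len(word)))
    char_offset += len(word) + len(delimiter)
# word_to_chars_map == [(0, 3), (4, 8)]

# With is_split_into_words=True, a fast tokenizer reports offsets relative to each word,
# so a token covering characters (0, 4) of word index 1 ("York") is rescaled like this:
word_index, start_ind, end_ind = 1, 0, 4
start_char, _ = word_to_chars_map[word_index]
start_ind, end_ind = start_ind + start_char, end_ind + start_char
assert sentence[start_ind:end_ind] == "York"  # absolute span (4, 8)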

View File

@@ -308,6 +308,54 @@ class TokenClassificationPipelineTests(unittest.TestCase):
],
)
@require_torch
@slow
def test_is_split_into_words(self):
"""
Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
and validates that the character offsets are correct.
"""
token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
# Input is a list of words
words = ["Hello", "Sarah", "lives", "in", "New", "York"]
# The reconstructed sentence will be "Hello Sarah lives in New York"
# - "Sarah": starts at index 6, ends at 11
# - "New York": starts at index 21, ends at 29
output = token_classifier(words, is_split_into_words=True)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
],
)
# Also test batching with pre-tokenized inputs
words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
batch_output = token_classifier([words, words2], is_split_into_words=True)
# Expected for second sentence ("My name is Wolfgang and I live in Berlin")
# - "Wolfgang": starts at 12, ends at 20
# - "Berlin": starts at 36, ends at 42
self.assertEqual(
nested_simplify(batch_output),
[
[
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
],
[
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
],
],
)
@require_torch
def test_chunking_fast(self):
# Note: We cannot run the test on "conflicts" on the chunking.
@@ -953,19 +1001,24 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
def test_simple(self):
string = "This is a simple input"
inputs, offset_mapping = self.args_parser(string)
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
self.assertEqual(inputs, [string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser([string, string])
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
self.assertEqual(inputs, [string, string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
string, offset_mapping=[(0, 1), (1, 2)]
)
self.assertEqual(inputs, [string])
self.assertFalse(is_split_into_words)
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
inputs, offset_mapping = self.args_parser(
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
)
self.assertEqual(inputs, [string, string])
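
For completeness, a small sketch of consuming the handler's new four-value return, mirroring `test_simple` above; the import path is an assumption based on the pipeline module diffed in this commit.

# Assumed import path for the handler class shown in the diff above.
from transformers.pipelines.token_classification import TokenClassificationArgumentHandler

args_parser = TokenClassificationArgumentHandler()
inputs, is_split_into_words, offset_mapping, delimiter = args_parser(
    ["New", "York"], is_split_into_words=True
)
# inputs == ["New", "York"], is_split_into_words is True,
# offset_mapping is None and delimiter is None (the pipeline later defaults it to " ").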