diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py
index e976ee9b187..2a2a6bbe109 100644
--- a/src/transformers/pipelines/token_classification.py
+++ b/src/transformers/pipelines/token_classification.py
@@ -30,6 +30,9 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
     """
 
     def __call__(self, inputs: Union[str, list[str]], **kwargs):
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        delimiter = kwargs.get("delimiter", None)
+
         if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
             inputs = list(inputs)
             batch_size = len(inputs)
@@ -37,7 +40,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
             inputs = [inputs]
             batch_size = 1
         elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
-            return inputs, None
+            return inputs, is_split_into_words, None, delimiter
         else:
             raise ValueError("At least one input is required.")
 
@@ -47,7 +50,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
                 offset_mapping = [offset_mapping]
             if len(offset_mapping) != batch_size:
                 raise ValueError("offset_mapping should have the same batch size as the input")
-        return inputs, offset_mapping
+        return inputs, is_split_into_words, offset_mapping, delimiter
 
 
 class AggregationStrategy(ExplicitEnum):
@@ -135,6 +138,7 @@ class TokenClassificationPipeline(ChunkPipeline):
 
     def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
         super().__init__(*args, **kwargs)
+
         self.check_model_type(
             TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
             if self.framework == "tf"
@@ -151,9 +155,16 @@ class TokenClassificationPipeline(ChunkPipeline):
         ignore_subwords: Optional[bool] = None,
         aggregation_strategy: Optional[AggregationStrategy] = None,
         offset_mapping: Optional[list[tuple[int, int]]] = None,
+        is_split_into_words: Optional[bool] = False,
         stride: Optional[int] = None,
+        delimiter: Optional[str] = None,
     ):
         preprocess_params = {}
+        preprocess_params["is_split_into_words"] = is_split_into_words
+
+        if is_split_into_words:
+            preprocess_params["delimiter"] = " " if delimiter is None else delimiter
+
         if offset_mapping is not None:
             preprocess_params["offset_mapping"] = offset_mapping
 
@@ -230,8 +241,9 @@ class TokenClassificationPipeline(ChunkPipeline):
         Classify each token of the text(s) given as inputs.
 
         Args:
-            inputs (`str` or `list[str]`):
-                One or several texts (or one list of texts) for token classification.
+            inputs (`str` or `List[str]`):
+                One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
+                `is_split_into_words=True`.
 
         Return:
             A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -251,7 +263,11 @@ class TokenClassificationPipeline(ChunkPipeline):
                 exists if the offsets are available within the tokenizer
         """
 
-        _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
+        _inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
+        kwargs["is_split_into_words"] = is_split_into_words
+        kwargs["delimiter"] = delimiter
+        if is_split_into_words and not all(isinstance(input, list) for input in inputs):
+            return super().__call__([inputs], **kwargs)
         if offset_mapping:
             kwargs["offset_mapping"] = offset_mapping
 
@@ -260,14 +276,43 @@ class TokenClassificationPipeline(ChunkPipeline):
     def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
         tokenizer_params = preprocess_params.pop("tokenizer_params", {})
         truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
+
+        word_to_chars_map = None
+        is_split_into_words = preprocess_params["is_split_into_words"]
+        if is_split_into_words:
+            delimiter = preprocess_params["delimiter"]
+            if not isinstance(sentence, list):
+                raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
+            words = sentence
+            sentence = delimiter.join(words)  # Recreate the sentence string for later display and slicing
+            # This map allows converting back word => char indices
+            word_to_chars_map = []
+            delimiter_len = len(delimiter)
+            char_offset = 0
+            for word in words:
+                word_to_chars_map.append((char_offset, char_offset + len(word)))
+                char_offset += len(word) + delimiter_len
+
+            # We use `words` as the actual input for the tokenizer
+            text_to_tokenize = words
+            tokenizer_params["is_split_into_words"] = True
+        else:
+            if not isinstance(sentence, str):
+                raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
+            text_to_tokenize = sentence
+
         inputs = self.tokenizer(
-            sentence,
+            text_to_tokenize,
             return_tensors=self.framework,
             truncation=truncation,
             return_special_tokens_mask=True,
             return_offsets_mapping=self.tokenizer.is_fast,
             **tokenizer_params,
         )
+
+        if is_split_into_words and not self.tokenizer.is_fast:
+            raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
+
         inputs.pop("overflow_to_sample_mapping", None)
         num_chunks = len(inputs["input_ids"])
 
@@ -278,8 +323,12 @@ class TokenClassificationPipeline(ChunkPipeline):
                 model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
             if offset_mapping is not None:
                 model_inputs["offset_mapping"] = offset_mapping
+
             model_inputs["sentence"] = sentence if i == 0 else None
             model_inputs["is_last"] = i == num_chunks - 1
+            if word_to_chars_map is not None:
+                model_inputs["word_ids"] = inputs.word_ids(i)
+                model_inputs["word_to_chars_map"] = word_to_chars_map
 
             yield model_inputs
 
@@ -289,6 +338,9 @@ class TokenClassificationPipeline(ChunkPipeline):
         offset_mapping = model_inputs.pop("offset_mapping", None)
         sentence = model_inputs.pop("sentence")
         is_last = model_inputs.pop("is_last")
+        word_ids = model_inputs.pop("word_ids", None)
+        word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
+
         if self.framework == "tf":
             logits = self.model(**model_inputs)[0]
         else:
@@ -301,6 +353,8 @@ class TokenClassificationPipeline(ChunkPipeline):
             "offset_mapping": offset_mapping,
             "sentence": sentence,
             "is_last": is_last,
+            "word_ids": word_ids,
+            "word_to_chars_map": word_to_chars_map,
             **model_inputs,
         }
 
@@ -308,6 +362,10 @@ class TokenClassificationPipeline(ChunkPipeline):
         if ignore_labels is None:
             ignore_labels = ["O"]
         all_entities = []
+
+        # Get the map from the first output; it's the same for all chunks
+        word_to_chars_map = all_outputs[0].get("word_to_chars_map")
+
         for model_outputs in all_outputs:
             if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
                 logits = model_outputs["logits"][0].to(torch.float32).numpy()
@@ -320,6 +378,7 @@ class TokenClassificationPipeline(ChunkPipeline):
                 model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
             )
             special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
+            word_ids = model_outputs.get("word_ids")
 
             maxes = np.max(logits, axis=-1, keepdims=True)
             shifted_exp = np.exp(logits - maxes)
@@ -330,7 +389,14 @@ class TokenClassificationPipeline(ChunkPipeline):
                 offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
 
             pre_entities = self.gather_pre_entities(
-                sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
+                sentence,
+                input_ids,
+                scores,
+                offset_mapping,
+                special_tokens_mask,
+                aggregation_strategy,
+                word_ids=word_ids,
+                word_to_chars_map=word_to_chars_map,
             )
             grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
             # Filter anything that is in self.ignore_labels
@@ -374,6 +440,8 @@ class TokenClassificationPipeline(ChunkPipeline):
         offset_mapping: Optional[list[tuple[int, int]]],
         special_tokens_mask: np.ndarray,
         aggregation_strategy: AggregationStrategy,
+        word_ids: Optional[list[Optional[int]]] = None,
+        word_to_chars_map: Optional[list[tuple[int, int]]] = None,
     ) -> list[dict]:
         """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
         pre_entities = []
@@ -385,6 +453,15 @@ class TokenClassificationPipeline(ChunkPipeline):
             word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
             if offset_mapping is not None:
                 start_ind, end_ind = offset_mapping[idx]
+
+                # If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
+                if word_ids is not None and word_to_chars_map is not None:
+                    word_index = word_ids[idx]
+                    if word_index is not None:
+                        start_char, _ = word_to_chars_map[word_index]
+                        start_ind += start_char
+                        end_ind += start_char
+
                 if not isinstance(start_ind, int):
                     if self.framework == "pt":
                         start_ind = start_ind.item()
diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py
index 5344ff980d4..643e4d6675d 100644
--- a/tests/pipelines/test_pipelines_token_classification.py
+++ b/tests/pipelines/test_pipelines_token_classification.py
@@ -308,6 +308,54 @@ class TokenClassificationPipelineTests(unittest.TestCase):
             ],
         )
 
+    @require_torch
+    @slow
+    def test_is_split_into_words(self):
+        """
+        Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
+        and validates that the character offsets are correct.
+        """
+        token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
+
+        # Input is a list of words
+        words = ["Hello", "Sarah", "lives", "in", "New", "York"]
+
+        # The reconstructed sentence will be "Hello Sarah lives in New York"
+        # - "Sarah": starts at index 6, ends at 11
+        # - "New York": starts at index 21, ends at 29
+
+        output = token_classifier(words, is_split_into_words=True)
+
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
+                {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
+            ],
+        )
+
+        # Also test batching with pre-tokenized inputs
+        words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
+        batch_output = token_classifier([words, words2], is_split_into_words=True)
+
+        # Expected for second sentence ("My name is Wolfgang and I live in Berlin")
+        # - "Wolfgang": starts at 12, ends at 20
+        # - "Berlin": starts at 36, ends at 42
+
+        self.assertEqual(
+            nested_simplify(batch_output),
+            [
+                [
+                    {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
+                    {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
+                ],
+                [
+                    {"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
+                    {"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
+                ],
+            ],
+        )
+
     @require_torch
     def test_chunking_fast(self):
         # Note: We cannot run the test on "conflicts" on the chunking.
@@ -953,19 +1001,24 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
 
     def test_simple(self):
         string = "This is a simple input"
-        inputs, offset_mapping = self.args_parser(string)
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
         self.assertEqual(inputs, [string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, None)
 
-        inputs, offset_mapping = self.args_parser([string, string])
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
        self.assertEqual(inputs, [string, string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, None)
 
-        inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
+            string, offset_mapping=[(0, 1), (1, 2)]
+        )
         self.assertEqual(inputs, [string])
+        self.assertFalse(is_split_into_words)
         self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
 
-        inputs, offset_mapping = self.args_parser(
+        inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
             [string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
         )
         self.assertEqual(inputs, [string, string])