Token healing (#30081)

* token healing impl + trie with extensions

* make fixup

* prefix-robust space tokenization

* examples readme and requirements

* make fixup

* allow input prompt and model

* redundant defaults

* Specialized Trie

* make fixup

* updated tests with new inherited Tree

* input ids to auto device_map

* rm unused import

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* naming convention

* Revert "naming convention"

This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0.

* naming convention

* last -hopefully- changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Ahmed Moubtahij 2024-06-03 04:53:15 -04:00 committed by GitHub
parent 5b5b48b11d
commit 39b2ff69d6
7 changed files with 324 additions and 5 deletions


@@ -0,0 +1,40 @@
<!-- back to top link -->
<a name="readme-top"></a>
<!-- ABOUT THE PROJECT -->
## What is token healing?
Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.
Example: given a completion prompt containing a partial URL that ends with `:`, the model may have seen the expected completion `://` as a _single_ token in training. However, the prompt's tail token `:` signals that the next token is probably not `//` (otherwise the tokenizer would have emitted `://` as one token), so the model searches for other, poorer completions. Such errors compound in auto-regressive language models.
Debiasing token boundaries also addresses output sensitivity to prompts ending with whitespace.
A more thorough explanation can be found on [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
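The bias is easy to see by inspecting the tokenizer directly. The snippet below is only an illustration and assumes the `gpt2` tokenizer (whose byte-level BPE vocabulary contains `://` as a single merge); it is not part of the example script and other tokenizers may split the prompts differently.
```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(tokenizer.tokenize('The link is <a href="http:')[-1])   # ':'  -> the prompt ends on a lone colon
print(tokenizer.tokenize('The link is <a href="http://')[-1])  # '://' -> the natural continuation is a single token
# Because the prompt ends with ':' rather than '://', the model is steered away from completions
# starting with '//'. A trailing space behaves the same way: it occupies the position where a
# space-prefixed word token would normally begin.
```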
## Usage
```py
prompt = 'The link is <a href="http:'
raw_output = generate(prompt, completion_model, tokenizer, token_healing=False)
# The link is <a href="http:&#47;&#47;www&#47;dailymail&#
# The model saw '://' as a single token in training. Seeing a prompt ending with `:` tells it that the
# next token is likely not `//`, because otherwise it would've seen `://`.
# Thus, it completes with a token other than `//`, in this case, `&`.
healed_output = generate(prompt, completion_model, tokenizer, token_healing=True)
# The link is <a href="http://www.365doki.com/post/3699
# You can also use token healing in isolation.
# This can be useful if you have other work to do before generating,
# or if you want to delegate generation to another process.
input_ids = tokenizer(prompt, return_tensors='pt', padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
# outputs the healed prompts without further completion/generation
```
See `run_token_healing.py` for the full example.
<p align="right">(<a href="#readme-top">back to top</a>)</p>


@@ -0,0 +1,62 @@
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
def generate(inputs, model, tokenizer, token_healing):
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.to(model.device)
generation_config = GenerationConfig(
max_new_tokens=8,
token_healing=token_healing,
pad_token_id=model.config.pad_token_id,
repetition_penalty=1.1,
)
output = model.generate(inputs=input_ids, generation_config=generation_config, tokenizer=tokenizer)
return tokenizer.batch_decode(output, skip_special_tokens=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str)
parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
args = parser.parse_args()
prompts = (
[args.prompt]
if args.prompt
else [
'An example ["like this"] and another example [',
'The link is <a href="http:',
'The link is <a href="http', # test aggressive healing http->https
"I read a book about ", # test trailing whitespace
"I read a book about", # test nothing to heal
]
)
model_name_or_path = args.model_name_or_path
completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)
for p, a, b in zip(prompts, raw_output, healed_output):
print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")
# You can also use token healing in isolation.
# This can be useful if you have other work to do before generating,
# or if you want to delegate generation to another process.
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
print("\nhealed prompts:")
for p in healed_prompts:
print(p)
if __name__ == "__main__":
main()


@@ -222,6 +222,9 @@ class GenerationConfig(PushToHubMixin):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. Check
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
token_healing (`bool`, *optional*, defaults to `False`):
Heal tail tokens of prompts by replacing them with their appropriate extensions.
This enhances the quality of completions for prompts affected by greedy tokenization bias.
guidance_scale (`float`, *optional*):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
@@ -360,6 +363,7 @@ class GenerationConfig(PushToHubMixin):
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.token_healing = kwargs.pop("token_healing", False)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)
watermarking_config = kwargs.pop("watermarking_config", None)
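For reference, a minimal sketch of the new flag in use. The `gpt2` checkpoint and padding setup below are illustrative assumptions rather than part of this change; note that the tokenizer must also be forwarded to `generate` so the prompt can be re-tokenized during healing.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "gpt2"  # placeholder checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # healing re-tokenizes with padding, so a pad token is needed
model = AutoModelForCausalLM.from_pretrained(model_name)

generation_config = GenerationConfig(
    max_new_tokens=8,
    token_healing=True,
    pad_token_id=tokenizer.pad_token_id,
)
inputs = tokenizer('The link is <a href="http:', return_tensors="pt")
output = model.generate(**inputs, generation_config=generation_config, tokenizer=tokenizer)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```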


@@ -42,6 +42,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
ModelOutput,
is_accelerate_available,
@@ -1591,6 +1592,8 @@ class GenerationMixin:
else:
synced_gpus = False
tokenizer = kwargs.pop("tokenizer", None)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -1653,6 +1656,9 @@ class GenerationMixin:
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
@@ -1989,6 +1995,75 @@ class GenerationMixin:
return False
return True
def heal_tokens(
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:
r"""
Heals the tail token of each input prompt by replacing it with an appropriate extension, mitigating the boundary bias of greedy tokenization.
Parameters:
input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
Return:
`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
"""
if tokenizer is None:
raise ValueError(
" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
"argument of `generate`."
)
bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
input_ids = tokenizer(
prompts,
return_tensors="pt",
padding=True,
).input_ids.to(input_ids.device)
# replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
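# (byte-level BPE tokenizers such as GPT-2's encode the leading space as 'Ġ', while SentencePiece-based tokenizers use '▁')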
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
batch_ids = input_ids[batch_idx]
if torch.all(batch_ids == pad_token_id).item():
continue # skip empty sequences (all pad ids)
# apply bias for alternatives (extensions) to the tail token
# sequence_bias keys must be tuples of token ids, so the alternative tokens are mapped back to ids
seq_bias = {(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)}
if len(seq_bias) == 1:
continue # skip if there are no token alternatives to heal with
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
seq_bias[(tail_id,)] += 1.0
generation_config.update(sequence_bias=seq_bias)
trimmed_ids = batch_ids[:-1]
# if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1:
trimmed_ids[-1] = bos_token_id
input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
return input_ids
def contrastive_search(self, *args, **kwargs):
logger.warning_once(
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._contrastive_search(*args, **kwargs)
@torch.no_grad()
def _contrastive_search(
self,
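Conceptually, `heal_tokens` above drops the tail token of each prompt and regrows exactly one token under a strong `sequence_bias` toward every vocabulary entry that extends the dropped token. The stand-alone sketch below distils that mechanism; it assumes the `gpt2` checkpoint and uses illustrative variable names, and is not the library implementation itself.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.tokenization_utils import ExtensionsTrie

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tokenizer('The link is <a href="http:', return_tensors="pt").input_ids
tail_id = ids[0, -1].item()
tail_tok = tokenizer.decode(tail_id)  # ':'

# bias every vocabulary token that extends the tail token (':' -> '://', '::', ...)
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
seq_bias = {(tokenizer.convert_tokens_to_ids(t),): 10.0 for t in vocab_trie.extensions(tail_tok)}
seq_bias[(tail_id,)] += 1.0  # slightly favour the original token to limit aggressive healing

# drop the tail token and regrow a single token under the bias
healed = model.generate(
    ids[:, :-1],
    max_new_tokens=1,
    sequence_bias=seq_bias,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(healed[0]))  # typically ends with '://' instead of ':'
```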


@@ -56,14 +56,26 @@ class Trie:
Loose reference https://en.wikipedia.org/wiki/Trie
"""
def __init__(self):
def __init__(self, *args):
self.data = {}
self._tokens = set()
self._termination_char = ""
self.update(*args)
def update(self, *args):
"""
Updates the Trie with new tokens provided as arguments.
Args:
*args: Variable number of words to be added to the Trie.
"""
for token in tuple(*args):
self.add(token)
def add(self, word: str):
"""
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
The special key `""` is used to represent termination.
The special key `""` in `self._termination_char` is used to represent termination.
This function is idempotent, adding twice the same word will leave the trie unchanged
@@ -87,9 +99,9 @@ class Trie:
self._tokens.add(word)
ref = self.data
for char in word:
ref[char] = char in ref and ref[char] or {}
ref[char] = ref.setdefault(char, {})
ref = ref[char]
ref[""] = 1
ref[self._termination_char] = 1
def split(self, text: str) -> List[str]:
"""
@@ -269,6 +281,62 @@ class Trie:
return tokens
class ExtensionsTrie(Trie):
def __init__(self, *args):
super().__init__(*args)
def extensions(self, prefix: str):
"""
Generates all extensions of a given prefix token in the Trie.
Example:
```python
>>> trie = ExtensionsTrie()
>>> trie.add("apple")
>>> trie.add("app")
>>> trie.add("application")
>>> trie.extensions("app")
['app', 'apple', 'application']
```
"""
prefix_node = self._get_node(prefix)
ret = self._collect_tokens(prefix_node)
return [prefix + token for token in ret]
def _get_node(self, token: str) -> dict:
"""
Retrieves the node corresponding to the given token in the Trie.
Args:
token (str): The token for which the corresponding node needs to be retrieved.
Returns:
dict: The node in the Trie corresponding to the given token.
"""
node = self.data
for char in token:
node = node[char]
return node
def _collect_tokens(self, node: dict) -> list:
"""
Generates all tokens in the Trie starting from a given node.
Args:
node (dict): The node in the Trie from which tokens need to be generated.
Returns:
list: List of tokens generated from the given node.
"""
tokens = [self._termination_char] if self._termination_char in node else []
for token, subtrie_head in node.items():
if token != self._termination_char:
subtokens = self._collect_tokens(subtrie_head)
tokens.extend([token + subtoken for subtoken in subtokens])
return tokens
def _is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them


@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_auto_gptq,
require_quanto,
require_torch,
require_torch_multi_accelerator,
@@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
@require_torch
class TokenHealingTestCase(unittest.TestCase):
@parameterized.expand(
[
(
"square_bracket",
'An example ["like this"] and another example [',
'An example ["like this"] and another example ["',
),
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
# aggressive_healing: "http" shouldn't be replaced with "https"
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
("trailing_whitespace", "I read a book about ", "I read a book about"),
("nothing_to_heal", "I read a book about", "I read a book about"),
("single_token", "I", "I"),
("empty_prompt", "", ""),
]
)
@require_auto_gptq
def test_prompts(self, name, input, expected):
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
trust_remote_code=False,
revision="main",
use_cache=True,
)
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
self.assertEqual(predicted, expected)
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)


@@ -33,7 +33,7 @@ from transformers import (
is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import Trie
from transformers.tokenization_utils import ExtensionsTrie, Trie
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
trie = Trie()
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
self.assertEqual(parts, ["AB", "C"])
class ExtensionsTrieTest(unittest.TestCase):
def test_extensions(self):
# Test searching by prefix
trie = ExtensionsTrie()
trie.add("foo")
trie.add("food")
trie.add("foodie")
trie.add("helium")
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
self.assertEqual(trie.extensions("helium"), ["helium"])
def test_empty_prefix(self):
trie = ExtensionsTrie()
# Test searching with an empty prefix returns all values
trie.add("hello")
trie.add("bye")
self.assertEqual(trie.extensions(""), ["hello", "bye"])
def test_no_extension_match(self):
trie = ExtensionsTrie()
# Test searching for a prefix that doesn't match any key
with self.assertRaises(KeyError):
trie.extensions("unknown")
def test_update_value(self):
trie = ExtensionsTrie()
# Test updating the value of an existing key
trie.add("hi")
trie.add("hi")
self.assertEqual(trie.extensions("hi"), ["hi"])