mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Token healing (#30081)
* token healing impl + trie with extensions * make fixup * prefix-robust space tokenization * examples readme and requirements * make fixup * allow input prompt and model * redundant defaults * Specialized Trie * make fixup * updated tests with new inherited Tree * input ids to auto device_map * rm unused import * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * naming convention * Revert "naming convention" This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0. * naming convention * last -hopefully- changes --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
5b5b48b11d
commit
39b2ff69d6
40
examples/research_projects/token-healing/README.md
Normal file
40
examples/research_projects/token-healing/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
<!-- back to top link -->
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<!-- ABOUT THE PROJECT -->
|
||||
## What is token healing?
|
||||
|
||||
Token healing rectifies the token boundary bias in greedy tokenization. It does this by trimming and regrowing the prompt to better align with the model's tokenizer, thus enhancing generation quality. The improvement is clearest with completion models.
|
||||
|
||||
Example: given a completion prompt with a partial url ending with `:`, the model might have seen the expected completion `://` as a _single_ token in training. However, the prompt's tail token `:` tells it that the next token is not `//`, and so it looks for wrong completions. Such errors compound in auto-regressive language models.
|
||||
|
||||
Debiasing token boundaries also addresses output sensitivity to prompts ending with whitespace.
|
||||
|
||||
A more thorough explanation can be found on [The Art of Prompt Design: Prompt Boundaries and Token Healing | by Scott Lundberg](https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38).
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
prompt = 'The link is <a href="http:'
|
||||
raw_output = generate(prompt, completion_model, tokenizer, token_healing=False)
|
||||
# The link is <a href="http://www/dailymail&#
|
||||
|
||||
# The model saw '://' as a single token in training. Seeing a prompt ending with `:` tells it that the
|
||||
# next token is likely not `//`, because otherwise it would've seen `://`.
|
||||
# Thus, it completes with a token other than `//`, in this case, `&`.
|
||||
|
||||
healed_output = generate(prompt, completion_model, tokenizer, token_healing=True)
|
||||
# The link is <a href="http://www.365doki.com/post/3699
|
||||
|
||||
# You can also use token healing in isolation
|
||||
# This can be useful if you have other work to do before the generation
|
||||
# Or if you want to delegate generation to another process
|
||||
input_ids = tokenizer(test_prompts, return_tensors='pt', padding=True).input_ids.cuda()
|
||||
healed_ids = model.heal_tokens(input_ids)
|
||||
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
|
||||
# outputs the healed prompts without further completion/generation
|
||||
```
|
||||
|
||||
See `run_token_healing.py` for the full example.
|
||||
|
||||
<p align="right">(<a href="#readme-top">back to top</a>)</p>
|
@ -0,0 +1,62 @@
|
||||
import argparse
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
|
||||
def generate(inputs, model, tokenizer, token_healing):
|
||||
input_ids = tokenizer(inputs, return_tensors="pt", padding=True, device_map="auto").input_ids
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=8,
|
||||
token_healing=token_healing,
|
||||
pad_token_id=model.config.pad_token_id,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
output = model.generate(inputs=input_ids, generation_config=generation_config)
|
||||
return tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--prompt", type=str)
|
||||
parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
|
||||
args = parser.parse_args()
|
||||
|
||||
prompts = (
|
||||
[args.prompt]
|
||||
if args.prompt
|
||||
else [
|
||||
'An example ["like this"] and another example [',
|
||||
'The link is <a href="http:',
|
||||
'The link is <a href="http', # test aggressive healing http->https
|
||||
"I read a book about ", # test trailing whitespace
|
||||
"I read a book about", # test nothing to heal
|
||||
]
|
||||
)
|
||||
|
||||
model_name_or_path = args.model_name_or_path
|
||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
device_map="auto",
|
||||
use_cache=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
|
||||
raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
|
||||
healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)
|
||||
|
||||
for p, a, b in zip(prompts, raw_output, healed_output):
|
||||
print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")
|
||||
|
||||
# You can also use token healing in isolation
|
||||
# This can be useful if you have other work to do before the generation
|
||||
# Or if you want to delegate generation to another process
|
||||
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
|
||||
healed_ids = completion_model.heal_tokens(input_ids)
|
||||
healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
|
||||
print("\nhealed prompts:")
|
||||
for p in healed_prompts:
|
||||
print(p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -222,6 +222,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
||||
sequence being selected, while negative biases do the opposite. Check
|
||||
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
|
||||
token_healing (`bool`, *optional*, defaults to `False`):
|
||||
Heal tail tokens of prompts by replacing them with their appropriate extensions.
|
||||
This enhances the quality of completions for prompts affected by greedy tokenization bias.
|
||||
guidance_scale (`float`, *optional*):
|
||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
@ -360,6 +363,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
||||
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
|
||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||
self.token_healing = kwargs.pop("token_healing", False)
|
||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||
self.low_memory = kwargs.pop("low_memory", None)
|
||||
watermarking_config = kwargs.pop("watermarking_config", None)
|
||||
|
@ -42,6 +42,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
from ..utils import (
|
||||
ModelOutput,
|
||||
is_accelerate_available,
|
||||
@ -1591,6 +1592,8 @@ class GenerationMixin:
|
||||
else:
|
||||
synced_gpus = False
|
||||
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
@ -1653,6 +1656,9 @@ class GenerationMixin:
|
||||
else:
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
|
||||
if generation_config.token_healing:
|
||||
input_ids = self.heal_tokens(input_ids, tokenizer)
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
@ -1989,6 +1995,75 @@ class GenerationMixin:
|
||||
return False
|
||||
return True
|
||||
|
||||
def heal_tokens(
|
||||
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
|
||||
) -> torch.LongTensor:
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head.
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
|
||||
tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
|
||||
Return:
|
||||
`torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
" When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
|
||||
"argument of `generate`."
|
||||
)
|
||||
|
||||
bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
|
||||
vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
|
||||
generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
|
||||
|
||||
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
|
||||
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
|
||||
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
|
||||
input_ids = tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).input_ids.to(input_ids.device)
|
||||
|
||||
# replace bos with pad to not condition healing on it
|
||||
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
|
||||
|
||||
tail_ids = input_ids[:, -1].tolist()
|
||||
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
|
||||
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
|
||||
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
|
||||
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
|
||||
|
||||
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
|
||||
batch_ids = input_ids[batch_idx]
|
||||
if torch.all(batch_ids == pad_token_id).item():
|
||||
continue # skip empty sequences (all pad ids)
|
||||
|
||||
# apply bias for alternatives (extensions) to the tail token
|
||||
seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
|
||||
if len(seq_bias) == 1:
|
||||
continue # skip if there are no token alternatives to heal with
|
||||
|
||||
# slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
|
||||
seq_bias[(tail_id,)] += 1.0
|
||||
generation_config.update(sequence_bias=seq_bias)
|
||||
|
||||
trimmed_ids = batch_ids[:-1]
|
||||
# if the prompt is a single (non-pad) token, regenerate from bos
|
||||
if len(batch_ids[batch_ids != pad_token_id]) == 1:
|
||||
trimmed_ids[-1] = bos_token_id
|
||||
|
||||
input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
|
||||
|
||||
return input_ids
|
||||
|
||||
def contrastive_search(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||
"custom generation loop instead.",
|
||||
)
|
||||
return self._contrastive_search(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def _contrastive_search(
|
||||
self,
|
||||
|
@ -56,14 +56,26 @@ class Trie:
|
||||
Loose reference https://en.wikipedia.org/wiki/Trie
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, *args):
|
||||
self.data = {}
|
||||
self._tokens = set()
|
||||
self._termination_char = ""
|
||||
self.update(*args)
|
||||
|
||||
def update(self, *args):
|
||||
"""
|
||||
Updates the Trie with new tokens provided as arguments.
|
||||
|
||||
Args:
|
||||
*args: Variable number of words to be added to the Trie.
|
||||
"""
|
||||
for token in tuple(*args):
|
||||
self.add(token)
|
||||
|
||||
def add(self, word: str):
|
||||
"""
|
||||
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
|
||||
The special key `""` is used to represent termination.
|
||||
The special key `""` in `self._termination_char` is used to represent termination.
|
||||
|
||||
This function is idempotent, adding twice the same word will leave the trie unchanged
|
||||
|
||||
@ -87,9 +99,9 @@ class Trie:
|
||||
self._tokens.add(word)
|
||||
ref = self.data
|
||||
for char in word:
|
||||
ref[char] = char in ref and ref[char] or {}
|
||||
ref[char] = ref.setdefault(char, {})
|
||||
ref = ref[char]
|
||||
ref[""] = 1
|
||||
ref[self._termination_char] = 1
|
||||
|
||||
def split(self, text: str) -> List[str]:
|
||||
"""
|
||||
@ -269,6 +281,62 @@ class Trie:
|
||||
return tokens
|
||||
|
||||
|
||||
class ExtensionsTrie(Trie):
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
|
||||
def extensions(self, prefix: str):
|
||||
"""
|
||||
Generates all extensions of a given prefix token in the Trie.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> trie = Trie()
|
||||
>>> trie.add("apple")
|
||||
>>> trie.add("app")
|
||||
>>> trie.add("application")
|
||||
>>> trie.extensions("app")
|
||||
['app', 'apple', 'application']
|
||||
```
|
||||
"""
|
||||
prefix_node = self._get_node(prefix)
|
||||
ret = self._collect_tokens(prefix_node)
|
||||
return [prefix + token for token in ret]
|
||||
|
||||
def _get_node(self, token: str) -> dict:
|
||||
"""
|
||||
Retrieves the node corresponding to the given token in the Trie.
|
||||
|
||||
Args:
|
||||
token (str): The token for which the corresponding node needs to be retrieved.
|
||||
|
||||
Returns:
|
||||
dict: The node in the Trie corresponding to the given token.
|
||||
"""
|
||||
node = self.data
|
||||
for char in token:
|
||||
node = node[char]
|
||||
return node
|
||||
|
||||
def _collect_tokens(self, node: dict) -> list:
|
||||
"""
|
||||
Generates all tokens in the Trie starting from a given node.
|
||||
|
||||
Args:
|
||||
node (dict): The node in the Trie from which tokens need to be generated.
|
||||
|
||||
Returns:
|
||||
list: List of tokens generated from the given node.
|
||||
"""
|
||||
tokens = [self._termination_char] if self._termination_char in node else []
|
||||
for token, subtrie_head in node.items():
|
||||
if token != self._termination_char:
|
||||
subtokens = self._collect_tokens(subtrie_head)
|
||||
tokens.extend([token + subtoken for subtoken in subtokens])
|
||||
return tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `char` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
|
@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_quanto,
|
||||
require_torch,
|
||||
require_torch_multi_accelerator,
|
||||
@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
(
|
||||
"square_bracket",
|
||||
'An example ["like this"] and another example [',
|
||||
'An example ["like this"] and another example ["',
|
||||
),
|
||||
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
||||
# aggressive_healing: "http" shouldn't be replaced with "https"
|
||||
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
||||
("trailing_whitespace", "I read a book about ", "I read a book about"),
|
||||
("nothing_to_heal", "I read a book about", "I read a book about"),
|
||||
("single_token", "I", "I"),
|
||||
("empty_prompt", "", ""),
|
||||
]
|
||||
)
|
||||
@require_auto_gptq
|
||||
def test_prompts(self, name, input, expected):
|
||||
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=False,
|
||||
revision="main",
|
||||
use_cache=True,
|
||||
)
|
||||
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
||||
|
||||
healed_ids = completion_model.heal_tokens(input_ids)
|
||||
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(predicted, expected)
|
||||
|
||||
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
@ -33,7 +33,7 @@ from transformers import (
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
|
||||
from transformers.tokenization_utils import Trie
|
||||
from transformers.tokenization_utils import ExtensionsTrie, Trie
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
|
||||
trie = Trie()
|
||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||
self.assertEqual(parts, ["AB", "C"])
|
||||
|
||||
|
||||
class ExtensionsTrieTest(unittest.TestCase):
|
||||
def test_extensions(self):
|
||||
# Test searching by prefix
|
||||
trie = ExtensionsTrie()
|
||||
trie.add("foo")
|
||||
trie.add("food")
|
||||
trie.add("foodie")
|
||||
trie.add("helium")
|
||||
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
|
||||
self.assertEqual(trie.extensions("helium"), ["helium"])
|
||||
|
||||
def test_empty_prefix(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching with an empty prefix returns all values
|
||||
trie.add("hello")
|
||||
trie.add("bye")
|
||||
self.assertEqual(trie.extensions(""), ["hello", "bye"])
|
||||
|
||||
def test_no_extension_match(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching for a prefix that doesn't match any key
|
||||
with self.assertRaises(KeyError):
|
||||
trie.extensions("unknown")
|
||||
|
||||
def test_update_value(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test updating the value of an existing key
|
||||
trie.add("hi")
|
||||
trie.add("hi")
|
||||
self.assertEqual(trie.extensions("hi"), ["hi"])
|
||||
|
Loading…
Reference in New Issue
Block a user