transformers/examples/research_projects/token-healing/run_token_healing.py
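
"""
Demo of token healing (#30081). Greedy tokenization can split a prompt at an
awkward boundary (a trailing space, or `http` when the vocabulary also contains
`https`), which skews the first generated token. With `token_healing=True`, the
prompt's trailing token is trimmed and regrown during generation so the
completion can extend it across the token boundary.

Example invocation (the model is configurable; the default below is a
GPTQ-quantized DeepSeek base model):

    python run_token_healing.py --prompt 'The link is <a href="http'
"""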

import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def generate(inputs, model, tokenizer, token_healing):
    # Tokenizers don't accept `device_map`; move the encoded ids to the model's device instead.
    input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.to(model.device)
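    # `token_healing` here toggles whether `generate()` trims and regrows the
    # prompt's final token before decoding (see `GenerationConfig` below).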
    generation_config = GenerationConfig(
        max_new_tokens=8,
        token_healing=token_healing,
        pad_token_id=model.config.pad_token_id,
        repetition_penalty=1.1,
    )
    output = model.generate(inputs=input_ids, generation_config=generation_config)
    return tokenizer.batch_decode(output, skip_special_tokens=True)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
    args = parser.parse_args()
    prompts = (
        [args.prompt]
        if args.prompt
        else [
            'An example ["like this"] and another example [',
            'The link is <a href="http:',
            'The link is <a href="http',  # test aggressive healing http -> https
            "I read a book about ",  # test trailing whitespace
            "I read a book about",  # test nothing to heal
        ]
    )
    model_name_or_path = args.model_name_or_path
    completion_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        use_cache=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
    healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)
    for p, a, b in zip(prompts, raw_output, healed_output):
        print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")
    # You can also run token healing in isolation. This is useful if you have
    # other work to do before generation, or if you want to delegate generation
    # to another process.
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(completion_model.device)
    healed_ids = completion_model.heal_tokens(input_ids)
    healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
    print("\nhealed prompts:")
    for p in healed_prompts:
        print(p)


if __name__ == "__main__":
    main()
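
# Illustrative effect (actual completions vary by model; this is a sketch, not
# captured output):
#   prompt:          'The link is <a href="http'
#   without healing: generation continues after the complete token `http`,
#                    often failing to form a valid URL
#   with healing:    the trailing token is trimmed and regrown, e.g.
#                    'The link is <a href="https://...'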