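"""Compare text generation with and without token healing.

Token healing repairs prompts that end mid-token (a trailing space, a partial URL
scheme, an unfinished bracket) by regenerating the prompt's final token before the
completion is produced, which typically gives cleaner continuations for such prompts.

The script runs a small set of test prompts (or a user-supplied one) through
`model.generate()` twice, once with `token_healing=False` and once with
`token_healing=True`, prints both outputs for comparison, and then demonstrates
`heal_tokens()` on its own.
"""
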
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def generate(inputs, model, tokenizer, token_healing):
    # Tokenize the prompts and move the input ids to the model's device.
    input_ids = tokenizer(inputs, return_tensors="pt", padding=True).input_ids.to(model.device)
    generation_config = GenerationConfig(
        max_new_tokens=8,
        token_healing=token_healing,
        pad_token_id=model.config.pad_token_id,
        repetition_penalty=1.1,
    )
    output = model.generate(inputs=input_ids, generation_config=generation_config)
    return tokenizer.batch_decode(output, skip_special_tokens=True)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--model_name_or_path", type=str, default="TheBloke/deepseek-llm-7B-base-GPTQ")
    args = parser.parse_args()

    prompts = (
        [args.prompt]
        if args.prompt
        else [
            'An example ["like this"] and another example [',
            'The link is <a href="http:',
            'The link is <a href="http',  # test aggressive healing http->https
            "I read a book about ",  # test trailing whitespace
            "I read a book about",  # test nothing to heal
        ]
    )

    model_name_or_path = args.model_name_or_path
    completion_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        use_cache=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    raw_output = generate(prompts, completion_model, tokenizer, token_healing=False)
    healed_output = generate(prompts, completion_model, tokenizer, token_healing=True)

    for p, a, b in zip(prompts, raw_output, healed_output):
        print(f"\nPrompt: {p}\nWithout healing:\n{a}\nWith healing:\n{b}")

    # You can also use token healing in isolation.
    # This can be useful if you have other work to do before generation,
    # or if you want to delegate generation to another process.
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(completion_model.device)
    healed_ids = completion_model.heal_tokens(input_ids)
    healed_prompts = tokenizer.batch_decode(healed_ids, skip_special_tokens=True)
    print("\nhealed prompts:")
    for p in healed_prompts:
        print(p)


if __name__ == "__main__":
    main()