[docs] no hard-coding cuda (#36043)

make device-agnostic
This commit is contained in:
Fanli Lin 2025-02-06 00:22:33 +08:00 committed by GitHub
parent 7399f8021e
commit 531d1511f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -57,15 +57,15 @@ More concretely, key-value cache acts as a memory bank for these generative mode
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
>>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> past_key_values = DynamicCache()
>>> messages = [{"role": "user", "content": "Hello, what's your name."}]
>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
>>> generated_ids = inputs.input_ids
>>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
>>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
>>> max_new_tokens = 10
>>> for _ in range(max_new_tokens):
@ -139,7 +139,7 @@ Cache quantization can be detrimental in terms of latency if the context length
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
@ -168,7 +168,7 @@ Use `cache_implementation="offloaded_static"` for an offloaded static cache (see
>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
@ -278,7 +278,7 @@ Note that you can use this cache only for models that support sliding window, e.
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
>>> tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
>>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to("cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Yesterday I was on a rock concert and.", return_tensors="pt").to(model.device)
>>> # can be used by passing in cache implementation
@ -298,7 +298,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
>>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)
>>> # get our cache, specify number of sink tokens and window size
@ -377,17 +377,19 @@ Sometimes you would want to first fill-in cache object with key/values for certa
>>> import copy
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
>>> from accelerate.test_utils.testing import get_backend
>>> DEVICE, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
>>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=DEVICE)
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
>>> # You can also init a DynamicCache, if that suits you better
>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=DEVICE, dtype=torch.bfloat16)
>>> INITIAL_PROMPT = "You are a helpful assistant. "
>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(DEVICE)
>>> # This is the common prompt cached, we need to run forward without grad to be abel to copy
>>> with torch.no_grad():
... prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values
@ -395,7 +397,7 @@ Sometimes you would want to first fill-in cache object with key/values for certa
>>> prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
>>> responses = []
>>> for prompt in prompts:
... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(DEVICE)
... past_key_values = copy.deepcopy(prompt_cache)
... outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
... response = tokenizer.batch_decode(outputs)[0]