Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
fix assisted decoding (#31401)
* fix assisted decoding
* check None
* fix typo
* fix _prepare_special_tokens
* fix style
* fix lint
* add tests for assisted decoding
* fix style
* fix tests check
parent f91c16d270
commit 7f91f168a1
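The tests added below exercise assisted decoding with the assistant (draft) model on a different device than the main model. As a quick orientation, a minimal usage sketch of that scenario; the tiny checkpoint and the device placement are illustrative and mirror the new tests:

# Sketch of the cross-device assisted-decoding scenario this commit targets.
# Checkpoint and devices are examples only.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda:0")
assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to("cpu")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer(["Hello world"], return_tensors="pt").to("cuda:0")
# Before this fix, special-token ids already stored as tensors were returned
# as-is by _tensor_or_none, so they could sit on the wrong device and break
# generation when the two models live on different devices.
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))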
src/transformers/generation/utils.py
@@ -1493,8 +1493,11 @@ class GenerationMixin:
                 device = self.device
 
             token = token_kwargs if token_kwargs is not None else token_self
-            if token is None or isinstance(token, torch.Tensor):
+            if token is None:
                 return token
+            elif isinstance(token, torch.Tensor):
+                return token.to(device)
+
             return torch.tensor(token, device=device, dtype=torch.long)
 
         bos_token_id = _tensor_or_none(
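For context, here is how the nested helper reads as a whole after this change. This is a reconstruction from the hunk above; the exact signature and the surrounding _prepare_special_tokens code are assumed rather than shown in the diff:

# Reconstructed sketch of the nested helper inside GenerationMixin._prepare_special_tokens
# after this commit; the (token_kwargs, token_self, device) parameter order is assumed.
# `self` and `torch` come from the enclosing method/module.
def _tensor_or_none(token_kwargs, token_self, device=None):
    if device is None:
        device = self.device

    token = token_kwargs if token_kwargs is not None else token_self
    if token is None:
        return token
    elif isinstance(token, torch.Tensor):
        # New behavior: move tensor-valued special tokens to the target device
        # instead of returning them untouched.
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)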
tests/generation/test_utils.py
@@ -30,7 +30,9 @@ from transformers.testing_utils import (
     require_auto_gptq,
     require_quanto,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_multi_gpu,
     slow,
     torch_device,
 )
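The two new imports, require_torch_gpu and require_torch_multi_gpu, are skip guards from transformers.testing_utils: they skip a test when the required hardware is not present. A hypothetical guard in the same spirit (not the library's actual implementation) could look like this:

# Hypothetical hardware guard illustrating what require_torch_multi_gpu-style
# decorators do: skip the test unless enough CUDA devices are available.
import unittest

import torch


def require_two_gpus(test_case):
    return unittest.skipUnless(
        torch.cuda.is_available() and torch.cuda.device_count() >= 2,
        "test requires at least 2 GPUs",
    )(test_case)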
@@ -3097,6 +3099,54 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
         self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_assisted_decoding_in_different_gpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cuda:1"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+    @slow
+    @require_torch_gpu
+    def test_assisted_decoding_in_gpu_cpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cpu"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
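Both new tests are decorated with @slow, so they only run when slow tests are enabled. Assuming they live in tests/generation/test_utils.py, they can be selected with something like RUN_SLOW=1 python -m pytest tests/generation/test_utils.py -k "assisted_decoding_in" -v (RUN_SLOW being the usual switch for transformers' slow test suite).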