fix assisted decoding (#31401)

* fix assisted decoding

* check None

* fix typo

* fix _prepare_special_tokens

* fix style

* fix lint

* add tests for assisted decoding

* fix style

* fix tests check
Authored by jiqing-feng on 2024-07-03 16:22:56 +08:00, committed by GitHub
parent f91c16d270
commit 7f91f168a1
2 changed files with 54 additions and 1 deletion
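For context, here is a minimal sketch of the assisted-decoding call this commit fixes. The checkpoint name is taken from the new tests below; the prompt and generation length are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
# The assistant (draft) model may live on a different device than the main model;
# that is the scenario the fix and the new tests target.
assistant = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer(["Hello world"], return_tensors="pt")
# Passing `assistant_model` enables assisted decoding.
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```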


@@ -1493,8 +1493,11 @@ class GenerationMixin:
                 device = self.device

             token = token_kwargs if token_kwargs is not None else token_self
-            if token is None or isinstance(token, torch.Tensor):
+            if token is None:
                 return token
+            elif isinstance(token, torch.Tensor):
+                return token.to(device)
             return torch.tensor(token, device=device, dtype=torch.long)

         bos_token_id = _tensor_or_none(
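The behavioral change above can be read in isolation: a special token that is already a `torch.Tensor` is now moved to the target device instead of being returned on whatever device it happened to be created on. A standalone sketch of the converter's new behavior (names mirror the diff; this is illustrative, not the full `_prepare_special_tokens` implementation):

```python
import torch

def _tensor_or_none(token, device):
    # Illustrative version of the converter shown in the diff above.
    if token is None:
        return token
    elif isinstance(token, torch.Tensor):
        # Previously a tensor was returned unchanged, so a bos/eos token created on
        # one device could leak into generation running on another device.
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)
```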


@@ -30,7 +30,9 @@ from transformers.testing_utils import (
     require_auto_gptq,
     require_quanto,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_multi_gpu,
     slow,
     torch_device,
 )
@@ -3097,6 +3099,54 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
         self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
 
+    @slow
+    @require_torch_multi_gpu
+    def test_assisted_decoding_in_different_gpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cuda:1"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+    @slow
+    @require_torch_gpu
+    def test_assisted_decoding_in_gpu_cpu(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+            "cpu"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            max_new_tokens=20,
+        )
+        self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):