Fix test_eos_token_id_int_and_list_top_k_top_sampling (#22826)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-04-18 16:04:51 +02:00 committed by GitHub
parent 1ebc1dee92
commit 90247d3e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2515,12 +2515,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
tokens = tokenizer(text, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
# Only some seeds will work both on CPU/GPU for a fixed `expectation` value.
# The selected seed is not guaranteed to work on all torch versions.
torch.manual_seed(1)
eos_token_id = 846
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
torch.manual_seed(1)
eos_token_id = [846, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))