Mirror of https://github.com/huggingface/transformers.git — synced 2025-08-02 19:21:31 +06:00.
Fix test_eos_token_id_int_and_list_top_k_top_sampling
(#22826)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
1ebc1dee92
commit
90247d3e01
@@ -2515,12 +2515,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):

         tokens = tokenizer(text, return_tensors="pt").to(torch_device)
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

-        torch.manual_seed(0)
+        # Only some seeds will work both on CPU/GPU for a fixed `expectation` value.
+        # The selected seed is not guaranteed to work on all torch versions.
+        torch.manual_seed(1)
         eos_token_id = 846
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))

-        torch.manual_seed(0)
+        torch.manual_seed(1)
         eos_token_id = [846, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
Loading…
Reference in New Issue
Block a user