Patch T5 device test (#12742)

This commit is contained in:
Lysandre Debut 2021-07-15 17:40:17 +02:00 committed by GitHub
parent 370be9cc38
commit f42d9dcc0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -802,7 +802,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
model.config.do_sample = False
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device)
sequences = model.generate(input_ids)