[marian tests ] pass device to pipeline (#4815)

This commit is contained in:
Sam Shleifer 2020-06-06 00:52:17 -04:00 committed by GitHub
parent ddf9a3dfc7
commit c58e6c129a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
self.tokenizer.prepare_translation_batch([""])
def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt")
device = 0 if torch_device == "cuda" else -1
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device)
output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])