Bart example: model.to(device) (#3194)

This commit is contained in:
Sam Shleifer 2020-03-09 15:09:35 -04:00 committed by GitHub
parent 5164ea91a7
commit 3aca02efb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ def chunks(lst, n):
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,)
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")
for batch in tqdm(list(chunks(lns, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)