Correct quickstart example when using the past

This commit is contained in:
Lysandre 2020-02-10 11:25:56 -05:00
parent 63a5399bc4
commit fd639e5be3

View File

@ -209,7 +209,7 @@ past = None
for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[0, :])
token = torch.argmax(output[..., -1, :])
generated += [token.tolist()]
context = token.unsqueeze(0)