Merge pull request #494 from SudoSharma/patch-1

Fix indentation for unconditional generation
This commit is contained in:
Thomas Wolf 2019-04-17 11:11:36 +02:00 committed by GitHub
commit 8407429d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -107,25 +107,25 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=None,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:,1:].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
break
if args.unconditional:
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=None,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:,1:].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if args.unconditional:
break
if __name__ == '__main__':
run_model()