diff --git a/examples/run_gpt2.py b/examples/run_gpt2.py index 03507474991..0289b267029 100644 --- a/examples/run_gpt2.py +++ b/examples/run_gpt2.py @@ -106,6 +106,23 @@ def run_model(): print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print(text) print("=" * 80) + if args.unconditional: + generated = 0 + for _ in range(args.nsamples // args.batch_size): + out = sample_sequence( + model=model, length=args.length, + context=None, + start_token=enc.encoder['<|endoftext|>'], + batch_size=args.batch_size, + temperature=args.temperature, top_k=args.top_k, device=device + ) + out = out[:,1:].tolist() + for i in range(args.batch_size): + generated += 1 + text = enc.decode(out[i]) + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(text) + print("=" * 80) if __name__ == '__main__': run_model()