diff --git a/examples/run_gpt2.py b/examples/run_gpt2.py index 03507474991..a30c6c6456a 100644 --- a/examples/run_gpt2.py +++ b/examples/run_gpt2.py @@ -83,7 +83,8 @@ def run_model(): elif args.length > model.config.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) - while not args.unconditional: + while True: + context_tokens = [] if not args.unconditional: raw_text = input("Model prompt >>> ") while not raw_text: @@ -106,6 +107,8 @@ def run_model(): print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print(text) print("=" * 80) + if args.unconditional: + break if __name__ == '__main__': run_model()