mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
making unconditional generation work
The unconditional generation works now but if the seed is fixed, the sample is the same every time. n_samples > 1 will give different samples though. I am giving the start token as '<|endoftext|>' for the unconditional generation.
This commit is contained in:
parent
694e2117f3
commit
f872eb98c2
@ -106,6 +106,23 @@ def run_model():
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
print("=" * 80)
|
||||
if args.unconditional:
|
||||
generated = 0
|
||||
for _ in range(args.nsamples // args.batch_size):
|
||||
out = sample_sequence(
|
||||
model=model, length=args.length,
|
||||
context=None,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature, top_k=args.top_k, device=device
|
||||
)
|
||||
out = out[:,1:].tolist()
|
||||
for i in range(args.batch_size):
|
||||
generated += 1
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
print("=" * 80)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_model()
|
||||
|
Loading…
Reference in New Issue
Block a user