From a9f24a16bc2965d2990b90127ed4b5a1f47344b9 Mon Sep 17 00:00:00 2001
From: mataney
Date: Wed, 25 Sep 2019 15:53:29 +0300
Subject: [PATCH] [FIX] fix run_generation.py to work with batch_size > 1

---
 examples/run_generation.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/examples/run_generation.py b/examples/run_generation.py
index a2a8f291031..935e578441d 100644
--- a/examples/run_generation.py
+++ b/examples/run_generation.py
@@ -81,7 +81,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
-    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
     top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
         # Remove all tokens with a probability less than the last token of the top-k
@@ -98,7 +97,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0

-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value
     return logits

@@ -122,10 +122,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
-            next_token_logits = outputs[0][0, -1, :] / temperature
+            next_token_logits = outputs[0][:, -1, :] / temperature
             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
+            generated = torch.cat((generated, next_token), dim=1)
     return generated

@@ -139,6 +139,7 @@ def main():
     parser.add_argument("--padding_text", type=str, default="")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--num_samples", type=int, default=1)
     parser.add_argument("--top_k", type=int, default=0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--no_cuda", action='store_true',
@@ -176,6 +177,7 @@ def main():
         out = sample_sequence(
             model=model,
             context=context_tokens,
+            num_samples=args.num_samples,
             length=args.length,
             temperature=args.temperature,
             top_k=args.top_k,
@@ -183,9 +185,10 @@ def main():
             device=args.device,
             is_xlnet=bool(args.model_type == "xlnet"),
         )
-        out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
-        print(text)
+        out = out[:, len(context_tokens):].tolist()
+        for o in out:
+            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
+            print(text)
         if args.prompt:
             break
     return text
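
Note (not part of the patch): a minimal standalone sketch of the batched nucleus (top-p) masking that the scatter call above enables, assuming PyTorch >= 1.0 and toy tensor shapes; top_p_mask and the example sizes are illustrative only and do not come from run_generation.py.

import torch
import torch.nn.functional as F

def top_p_mask(logits, top_p=0.9):
    # logits: (batch_size, vocab_size)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Map the removal mask from sorted order back to the original vocabulary
    # order, mirroring the scatter call introduced by the patch
    return sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)

logits = torch.randn(4, 10)  # batch_size = 4, vocab_size = 10
filtered = logits.masked_fill(top_p_mask(logits, top_p=0.9), -float('Inf'))
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)  # shape (4, 1)

Scattering back to vocabulary order yields a per-row boolean mask instead of the flat index list produced by the old 1-D advanced indexing, so the subsequent logits[indices_to_remove] = filter_value line works unchanged for any batch size.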