diff --git a/examples/run_generation.py b/examples/run_generation.py index b8cc8a9bbf8..2d917660cf7 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -79,7 +79,7 @@ def set_seed(args): def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: - logits: logits distribution shape (vocabulary size) + logits: logits distribution shape (batch size x vocabulary size) top_k > 0: keep only top k tokens with highest probability (top-k filtering). top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) @@ -138,13 +138,14 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states) next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.) - # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858) - for _ in set(generated.view(-1).tolist()): - next_token_logits[_] /= repetition_penalty + # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) + for i in range(num_samples): + for _ in set(generated[i].tolist()): + next_token_logits[i, _] /= repetition_penalty filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) - if temperature == 0: #greedy sampling: - next_token = torch.argmax(filtered_logits).unsqueeze(0) + if temperature == 0: # greedy sampling: + next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1) else: next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) generated = torch.cat((generated, next_token), dim=1)