[FIX] fix run_generation.py to work with batch_size > 1

This commit is contained in:
mataney 2019-09-25 15:53:29 +03:00
parent 7c0f2d0a6a
commit a9f24a16bc

View File

@ -81,7 +81,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
@ -98,7 +97,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
@ -122,10 +122,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / temperature
next_token_logits = outputs[0][:, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
generated = torch.cat((generated, next_token), dim=1)
return generated
@ -139,6 +139,7 @@ def main():
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
@ -176,6 +177,7 @@ def main():
out = sample_sequence(
model=model,
context=context_tokens,
num_samples=args.num_samples,
length=args.length,
temperature=args.temperature,
top_k=args.top_k,
@ -183,9 +185,10 @@ def main():
device=args.device,
is_xlnet=bool(args.model_type == "xlnet"),
)
out = out[0, len(context_tokens):].tolist()
text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
print(text)
out = out[:, len(context_tokens):].tolist()
for o in out:
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
print(text)
if args.prompt:
break
return text