mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[FIX] fix run_generation.py to work with batch_size > 1
This commit is contained in:
parent
7c0f2d0a6a
commit
a9f24a16bc
@ -81,7 +81,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
@ -98,7 +97,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
@ -122,10 +122,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||
next_token_logits = outputs[0][0, -1, :] / temperature
|
||||
next_token_logits = outputs[0][:, -1, :] / temperature
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
||||
generated = torch.cat((generated, next_token), dim=1)
|
||||
return generated
|
||||
|
||||
|
||||
@ -139,6 +139,7 @@ def main():
|
||||
parser.add_argument("--padding_text", type=str, default="")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--num_samples", type=int, default=1)
|
||||
parser.add_argument("--top_k", type=int, default=0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
@ -176,6 +177,7 @@ def main():
|
||||
out = sample_sequence(
|
||||
model=model,
|
||||
context=context_tokens,
|
||||
num_samples=args.num_samples,
|
||||
length=args.length,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
@ -183,9 +185,10 @@ def main():
|
||||
device=args.device,
|
||||
is_xlnet=bool(args.model_type == "xlnet"),
|
||||
)
|
||||
out = out[0, len(context_tokens):].tolist()
|
||||
text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
|
||||
print(text)
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for o in out:
|
||||
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
||||
print(text)
|
||||
if args.prompt:
|
||||
break
|
||||
return text
|
||||
|
Loading…
Reference in New Issue
Block a user