mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[FIX] fix run_generation.py to work with batch_size > 1
This commit is contained in:
parent
7c0f2d0a6a
commit
a9f24a16bc
@ -81,7 +81,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
|||||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||||
"""
|
"""
|
||||||
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
|
||||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||||
if top_k > 0:
|
if top_k > 0:
|
||||||
# Remove all tokens with a probability less than the last token of the top-k
|
# Remove all tokens with a probability less than the last token of the top-k
|
||||||
@ -98,7 +97,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
|||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
# scatter sorted tensors to original indexing
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@ -122,10 +122,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
|||||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||||
|
|
||||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||||
next_token_logits = outputs[0][0, -1, :] / temperature
|
next_token_logits = outputs[0][:, -1, :] / temperature
|
||||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||||
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
generated = torch.cat((generated, next_token), dim=1)
|
||||||
return generated
|
return generated
|
||||||
|
|
||||||
|
|
||||||
@ -139,6 +139,7 @@ def main():
|
|||||||
parser.add_argument("--padding_text", type=str, default="")
|
parser.add_argument("--padding_text", type=str, default="")
|
||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
|
parser.add_argument("--num_samples", type=int, default=1)
|
||||||
parser.add_argument("--top_k", type=int, default=0)
|
parser.add_argument("--top_k", type=int, default=0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
parser.add_argument("--no_cuda", action='store_true',
|
||||||
@ -176,6 +177,7 @@ def main():
|
|||||||
out = sample_sequence(
|
out = sample_sequence(
|
||||||
model=model,
|
model=model,
|
||||||
context=context_tokens,
|
context=context_tokens,
|
||||||
|
num_samples=args.num_samples,
|
||||||
length=args.length,
|
length=args.length,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
@ -183,9 +185,10 @@ def main():
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
is_xlnet=bool(args.model_type == "xlnet"),
|
is_xlnet=bool(args.model_type == "xlnet"),
|
||||||
)
|
)
|
||||||
out = out[0, len(context_tokens):].tolist()
|
out = out[:, len(context_tokens):].tolist()
|
||||||
text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
|
for o in out:
|
||||||
print(text)
|
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
|
||||||
|
print(text)
|
||||||
if args.prompt:
|
if args.prompt:
|
||||||
break
|
break
|
||||||
return text
|
return text
|
||||||
|
Loading…
Reference in New Issue
Block a user