Merge pull request #1313 from enzoampil/master

Add option to use a 'stop token'
This commit is contained in:
Lysandre Debut 2019-10-03 22:43:57 +00:00 committed by GitHub
commit 81a1e12469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -152,6 +152,8 @@ def main():
help="Avoid using CUDA when available") help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42, parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization") help="random seed for initialization")
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
args = parser.parse_args() args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@ -204,7 +206,10 @@ def main():
device=args.device, device=args.device,
) )
out = out[0, len(context_tokens):].tolist() out = out[0, len(context_tokens):].tolist()
text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True) text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
print(text) print(text)
if args.prompt: if args.prompt:
break break