mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Merge pull request #1313 from enzoampil/master
Add option to use a 'stop token'
This commit is contained in:
commit
81a1e12469
@ -152,6 +152,8 @@ def main():
|
|||||||
help="Avoid using CUDA when available")
|
help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--seed', type=int, default=42,
|
parser.add_argument('--seed', type=int, default=42,
|
||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
|
parser.add_argument('--stop_token', type=str, default=None,
|
||||||
|
help="Token at which text generation is stopped")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||||
@ -204,7 +206,10 @@ def main():
|
|||||||
device=args.device,
|
device=args.device,
|
||||||
)
|
)
|
||||||
out = out[0, len(context_tokens):].tolist()
|
out = out[0, len(context_tokens):].tolist()
|
||||||
|
|
||||||
text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
||||||
|
text = text[: text.find(args.stop_token) if args.stop_token else None]
|
||||||
|
|
||||||
print(text)
|
print(text)
|
||||||
if args.prompt:
|
if args.prompt:
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user