diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 5c9bd76bdb0..04a1b11a21c 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -26,6 +26,7 @@ from typing import Optional import yaml +from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer from transformers.utils import is_rich_available, is_torch_available from . import BaseTransformersCLICommand @@ -42,13 +43,7 @@ if is_rich_available(): if is_torch_available(): import torch - from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GenerationConfig, - TextIteratorStreamer, - ) + from transformers import AutoModelForCausalLM, BitsAndBytesConfig ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace) @@ -547,7 +542,7 @@ class ChatCommand(BaseTransformersCLICommand): return quantization_config - def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]: + def load_model_and_tokenizer(self, args: ChatArguments) -> tuple["AutoModelForCausalLM", AutoTokenizer]: tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path_positional, revision=args.model_revision,