diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py
index 0c6b4702a43..7ade958149a 100644
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
+import copy
 import json
 import os
 import platform
@@ -28,7 +29,13 @@ from typing import Optional
 import yaml
 from huggingface_hub.utils import disable_progress_bars
 
-from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer, logging
+from transformers import (
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    TextIteratorStreamer,
+    logging,
+)
 from transformers.utils import is_rich_available, is_torch_available
 
 from . import BaseTransformersCLICommand
@@ -45,7 +52,7 @@ if is_rich_available():
 if is_torch_available():
     import torch
 
-    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
 
 
 ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -248,7 +255,9 @@ class ChatArguments:
     repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
     eos_tokens: Optional[str] = field(
         default=None,
-        metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
+        metadata={
+            "help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
+        },
     )
     eos_token_ids: Optional[str] = field(
         default=None,
@@ -469,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand):
         return processed_generate_flags
 
     def get_generation_parameterization(
-        self, args: ChatArguments, tokenizer: AutoTokenizer
+        self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
     ) -> tuple[GenerationConfig, dict]:
         """
         Returns a GenerationConfig object holding the generation parameters for the CLI command.
         """
-        # No generation config arg provided -> use base generation config, apply CLI defaults
+        # No generation config arg provided -> use default generation config, apply CLI defaults
        if args.generation_config is None:
-            generation_config = GenerationConfig()
+            # We start off from the checkpoint's generation config
+            generation_config = copy.deepcopy(model.generation_config)
             # Apply deprecated CLI args on top of the default generation config
-            pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+            pad_token_id, eos_token_ids = self.parse_eos_tokens(
+                tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
+            )
             deprecated_kwargs = {
                 "max_new_tokens": args.max_new_tokens,
                 "do_sample": args.do_sample,
@@ -509,13 +521,16 @@ class ChatCommand(BaseTransformersCLICommand):
 
     @staticmethod
     def parse_eos_tokens(
-        tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
+        tokenizer: PreTrainedTokenizer,
+        generation_config: GenerationConfig,
+        eos_tokens: Optional[str],
+        eos_token_ids: Optional[str],
     ) -> tuple[int, list[int]]:
         """Retrieves the pad token ID and all possible EOS token IDs."""
-        if tokenizer.pad_token_id is None:
-            pad_token_id = tokenizer.eos_token_id
+        if generation_config.pad_token_id is None:
+            pad_token_id = generation_config.eos_token_id
         else:
-            pad_token_id = tokenizer.pad_token_id
+            pad_token_id = generation_config.pad_token_id
 
         all_eos_token_ids = []
 
@@ -526,7 +541,7 @@ class ChatCommand(BaseTransformersCLICommand):
            all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
 
         if len(all_eos_token_ids) == 0:
-            all_eos_token_ids.append(tokenizer.eos_token_id)
+            all_eos_token_ids.append(generation_config.eos_token_id)
 
         return pad_token_id, all_eos_token_ids
 
@@ -683,7 +698,7 @@ class ChatCommand(BaseTransformersCLICommand):
 
         model, tokenizer = self.load_model_and_tokenizer(args)
         generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
-        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
+        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
 
         # if not verbose -> disable warnings, progress bars, etc in the chat interface
         if not args.verbose: