[chat] use the checkpoint's generation_config.json as base parameterization (#38330)

* use model gen config

* unwanted diff
Joao Gante authored on 2025-05-27 11:35:33 +01:00, committed by GitHub
parent 008e0d87c5
commit 80902ae9b1

src/transformers/commands/chat.py
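What the change means in practice: the chat CLI previously started from library-default generation parameters (`GenerationConfig()`); it now starts from whatever the checkpoint ships in its `generation_config.json`, exposed as `model.generation_config`. A minimal sketch of the difference (the checkpoint name is an arbitrary example, not part of the commit):

```python
import copy

from transformers import AutoModelForCausalLM, GenerationConfig

# Any chat checkpoint works; this one is only an illustrative choice.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

old_base = GenerationConfig()                      # library defaults (old behavior)
new_base = copy.deepcopy(model.generation_config)  # checkpoint defaults (new behavior)
```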

@@ -13,6 +13,7 @@
 # limitations under the License.
+import copy
 import json
 import os
 import platform
@@ -28,7 +29,13 @@ from typing import Optional
 import yaml
 from huggingface_hub.utils import disable_progress_bars
 
-from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer, logging
+from transformers import (
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    TextIteratorStreamer,
+    logging,
+)
 from transformers.utils import is_rich_available, is_torch_available
 
 from . import BaseTransformersCLICommand
@@ -45,7 +52,7 @@ if is_rich_available():
 if is_torch_available():
     import torch
 
-    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
 
 
 ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -248,7 +255,9 @@ class ChatArguments:
     repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
     eos_tokens: Optional[str] = field(
         default=None,
-        metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
+        metadata={
+            "help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
+        },
     )
     eos_token_ids: Optional[str] = field(
         default=None,
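The two deprecated flags accept comma-separated values in different formats: `--eos_tokens` takes token text, `--eos_token_ids` takes integer ids. A rough sketch of how such values map to ids; the `convert_tokens_to_ids` call for the text branch is an assumption about code outside this diff, and the values are made up:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

eos_tokens = "<|im_end|>,<|endoftext|>"  # --eos_tokens: token text, comma separated
eos_token_ids = "2,106"                  # --eos_token_ids: made-up integer ids

ids_from_text = tokenizer.convert_tokens_to_ids(eos_tokens.split(","))
ids_from_ids = [int(token_id) for token_id in eos_token_ids.split(",")]
```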
@@ -469,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand):
         return processed_generate_flags
 
     def get_generation_parameterization(
-        self, args: ChatArguments, tokenizer: AutoTokenizer
+        self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
     ) -> tuple[GenerationConfig, dict]:
         """
         Returns a GenerationConfig object holding the generation parameters for the CLI command.
         """
-        # No generation config arg provided -> use base generation config, apply CLI defaults
+        # No generation config arg provided -> use default generation config, apply CLI defaults
         if args.generation_config is None:
-            generation_config = GenerationConfig()
+            # We start off from the checkpoint's generation config
+            generation_config = copy.deepcopy(model.generation_config)
 
             # Apply deprecated CLI args on top of the default generation config
-            pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+            pad_token_id, eos_token_ids = self.parse_eos_tokens(
+                tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
+            )
             deprecated_kwargs = {
                 "max_new_tokens": args.max_new_tokens,
                 "do_sample": args.do_sample,
@@ -509,13 +521,16 @@
     @staticmethod
     def parse_eos_tokens(
-        tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
+        tokenizer: PreTrainedTokenizer,
+        generation_config: GenerationConfig,
+        eos_tokens: Optional[str],
+        eos_token_ids: Optional[str],
     ) -> tuple[int, list[int]]:
         """Retrieves the pad token ID and all possible EOS token IDs."""
-        if tokenizer.pad_token_id is None:
-            pad_token_id = tokenizer.eos_token_id
+        if generation_config.pad_token_id is None:
+            pad_token_id = generation_config.eos_token_id
         else:
-            pad_token_id = tokenizer.pad_token_id
+            pad_token_id = generation_config.pad_token_id
 
         all_eos_token_ids = []
@@ -526,7 +541,7 @@
             all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
 
         if len(all_eos_token_ids) == 0:
-            all_eos_token_ids.append(tokenizer.eos_token_id)
+            all_eos_token_ids.append(generation_config.eos_token_id)
 
         return pad_token_id, all_eos_token_ids
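Rewritten as a standalone helper, the new pad-token resolution reads as below (a sketch mirroring the branch in the diff). Note that `eos_token_id` in a generation config may itself be a list of ids for some checkpoints, so a caller may need to handle both an `int` and a `list[int]`:

```python
from transformers import GenerationConfig

def resolve_pad_token_id(generation_config: GenerationConfig):
    # Fall back to the EOS token when the config defines no pad token.
    if generation_config.pad_token_id is None:
        return generation_config.eos_token_id
    return generation_config.pad_token_id
```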
@@ -683,7 +698,7 @@
         model, tokenizer = self.load_model_and_tokenizer(args)
         generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
-        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
+        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
 
         # if not verbose -> disable warnings, progress bars, etc in the chat interface
         if not args.verbose:
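Putting the pieces together, an end-to-end approximation of what the updated command now does before entering its chat loop (a simplified sketch, not the actual `run` method; checkpoint and prompt are placeholders):

```python
import copy

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Base parameterization now comes from the checkpoint's generation_config.json.
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(max_new_tokens=64)

inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
    return_tensors="pt",
)
outputs = model.generate(inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```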