commit 80902ae9b1 (parent 008e0d87c5)

[chat] use the checkpoint's generation_config.json as base parameterization (#38330)

* use model gen config
* unwanted diff
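
In practice, this means `transformers chat` now seeds its generation parameters from whatever the checkpoint ships in `generation_config.json`, rather than from a fresh `GenerationConfig()`. A minimal sketch of where that base comes from (the checkpoint name below is only an example):

    from transformers import AutoModelForCausalLM

    # from_pretrained() also loads the checkpoint's generation_config.json,
    # exposing it as model.generation_config
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    print(model.generation_config)  # checkpoint defaults: temperature, top_p, eos_token_id, ...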
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import copy
 import json
 import os
 import platform
@@ -28,7 +29,13 @@ from typing import Optional
 import yaml
 from huggingface_hub.utils import disable_progress_bars
 
-from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer, logging
+from transformers import (
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    TextIteratorStreamer,
+    logging,
+)
 from transformers.utils import is_rich_available, is_torch_available
 
 from . import BaseTransformersCLICommand
@@ -45,7 +52,7 @@ if is_rich_available():
 if is_torch_available():
     import torch
 
-    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
 
 
 ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -248,7 +255,9 @@ class ChatArguments:
     repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
     eos_tokens: Optional[str] = field(
         default=None,
-        metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
+        metadata={
+            "help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
+        },
     )
     eos_token_ids: Optional[str] = field(
         default=None,
@@ -469,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand):
         return processed_generate_flags
 
     def get_generation_parameterization(
-        self, args: ChatArguments, tokenizer: AutoTokenizer
+        self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
     ) -> tuple[GenerationConfig, dict]:
         """
         Returns a GenerationConfig object holding the generation parameters for the CLI command.
         """
-        # No generation config arg provided -> use base generation config, apply CLI defaults
+        # No generation config arg provided -> use default generation config, apply CLI defaults
         if args.generation_config is None:
-            generation_config = GenerationConfig()
+            # We start off from the checkpoint's generation config
+            generation_config = copy.deepcopy(model.generation_config)
             # Apply deprecated CLI args on top of the default generation config
-            pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+            pad_token_id, eos_token_ids = self.parse_eos_tokens(
+                tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
+            )
             deprecated_kwargs = {
                 "max_new_tokens": args.max_new_tokens,
                 "do_sample": args.do_sample,
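
The checkpoint config is deep-copied so CLI flags can be layered on top without mutating `model.generation_config` itself. A rough sketch of the idea, with hypothetical flag values:

    import copy

    # start from the checkpoint's defaults rather than library-wide ones
    generation_config = copy.deepcopy(model.generation_config)

    # CLI overrides (values here are illustrative) are applied on top;
    # GenerationConfig.update() returns any kwargs it did not consume
    unused = generation_config.update(max_new_tokens=256, do_sample=True, temperature=0.7)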
@@ -509,13 +521,16 @@ class ChatCommand(BaseTransformersCLICommand):
 
     @staticmethod
     def parse_eos_tokens(
-        tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
+        tokenizer: PreTrainedTokenizer,
+        generation_config: GenerationConfig,
+        eos_tokens: Optional[str],
+        eos_token_ids: Optional[str],
     ) -> tuple[int, list[int]]:
         """Retrieves the pad token ID and all possible EOS token IDs."""
-        if tokenizer.pad_token_id is None:
-            pad_token_id = tokenizer.eos_token_id
+        if generation_config.pad_token_id is None:
+            pad_token_id = generation_config.eos_token_id
         else:
-            pad_token_id = tokenizer.pad_token_id
+            pad_token_id = generation_config.pad_token_id
 
         all_eos_token_ids = []
 
@@ -526,7 +541,7 @@ class ChatCommand(BaseTransformersCLICommand):
             all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
 
         if len(all_eos_token_ids) == 0:
-            all_eos_token_ids.append(tokenizer.eos_token_id)
+            all_eos_token_ids.append(generation_config.eos_token_id)
 
         return pad_token_id, all_eos_token_ids
 
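
With the new signature, pad/EOS resolution is driven by the generation config instead of the tokenizer, so checkpoint-declared special tokens take precedence. A hedged usage sketch (the EOS token string here is illustrative):

    # resolves pad_token_id (falling back to the config's eos_token_id when unset)
    # and collects EOS ids from text tokens, explicit ids, or the config default
    pad_token_id, all_eos_token_ids = ChatCommand.parse_eos_tokens(
        tokenizer, generation_config, eos_tokens="<|im_end|>", eos_token_ids=None
    )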
@@ -683,7 +698,7 @@ class ChatCommand(BaseTransformersCLICommand):
 
         model, tokenizer = self.load_model_and_tokenizer(args)
         generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
-        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
+        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
 
         # if not verbose -> disable warnings, progress bars, etc in the chat interface
         if not args.verbose: