[chat] use the checkpoint's generation_config.json as base parameterization (#38330)

* use model gen config
* unwanted diff
parent 008e0d87c5
commit 80902ae9b1
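In short: instead of seeding the chat CLI's generation parameters from a fresh GenerationConfig() (library defaults), the command now deep-copies the loaded model's generation_config, which from_pretrained populates from the checkpoint's generation_config.json. A minimal sketch of the behavioral difference (the model name is illustrative, not from the diff):

import copy

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

old_base = GenerationConfig()                      # library defaults; ignores the checkpoint
new_base = copy.deepcopy(model.generation_config)  # values shipped in generation_config.json

# Checkpoints often ship sampling defaults (temperature, top_p, extra EOS ids)
# that the old base silently dropped:
print(old_base.temperature, new_base.temperature)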
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import copy
 import json
 import os
 import platform
@@ -28,7 +29,13 @@ from typing import Optional
 import yaml
 from huggingface_hub.utils import disable_progress_bars
 
-from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer, logging
+from transformers import (
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    TextIteratorStreamer,
+    logging,
+)
 from transformers.utils import is_rich_available, is_torch_available
 
 from . import BaseTransformersCLICommand
@@ -45,7 +52,7 @@ if is_rich_available():
 if is_torch_available():
     import torch
 
-    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
 
 
 ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -248,7 +255,9 @@ class ChatArguments:
     repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
     eos_tokens: Optional[str] = field(
         default=None,
-        metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
+        metadata={
+            "help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
+        },
     )
     eos_token_ids: Optional[str] = field(
         default=None,
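The reworded help text makes explicit that this flag takes the EOS tokens as text, as opposed to the integer ids accepted by eos_token_ids just below it. Both flags are comma-separated lists; a toy illustration of the split (the flag value is made up):

eos_tokens = "<|im_end|>,<|endoftext|>"  # hypothetical --eos_tokens value
print(eos_tokens.split(","))             # ['<|im_end|>', '<|endoftext|>']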
@@ -469,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand):
         return processed_generate_flags
 
     def get_generation_parameterization(
-        self, args: ChatArguments, tokenizer: AutoTokenizer
+        self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
     ) -> tuple[GenerationConfig, dict]:
         """
         Returns a GenerationConfig object holding the generation parameters for the CLI command.
         """
-        # No generation config arg provided -> use base generation config, apply CLI defaults
+        # No generation config arg provided -> use default generation config, apply CLI defaults
         if args.generation_config is None:
-            generation_config = GenerationConfig()
+            # We start off from the checkpoint's generation config
+            generation_config = copy.deepcopy(model.generation_config)
             # Apply deprecated CLI args on top of the default generation config
-            pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+            pad_token_id, eos_token_ids = self.parse_eos_tokens(
+                tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
+            )
             deprecated_kwargs = {
                 "max_new_tokens": args.max_new_tokens,
                 "do_sample": args.do_sample,
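The effective precedence after this hunk: the checkpoint's generation_config.json is the base, and any deprecated CLI flags that were actually passed are layered on top. A rough, self-contained sketch of that layering (GenerationConfig.update is a convenient stand-in here, not necessarily the exact mechanism the command uses):

import copy

from transformers import GenerationConfig

# Hypothetical stand-ins for the checkpoint config and the parsed CLI flags
checkpoint_config = GenerationConfig(temperature=0.6, top_p=0.9)
cli_flags = {"max_new_tokens": 256, "do_sample": None}  # None -> flag not passed

generation_config = copy.deepcopy(checkpoint_config)
generation_config.update(**{k: v for k, v in cli_flags.items() if v is not None})
print(generation_config.temperature, generation_config.max_new_tokens)  # 0.6 256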
@@ -509,13 +521,16 @@ class ChatCommand(BaseTransformersCLICommand):
 
     @staticmethod
     def parse_eos_tokens(
-        tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
+        tokenizer: PreTrainedTokenizer,
+        generation_config: GenerationConfig,
+        eos_tokens: Optional[str],
+        eos_token_ids: Optional[str],
     ) -> tuple[int, list[int]]:
         """Retrieves the pad token ID and all possible EOS token IDs."""
-        if tokenizer.pad_token_id is None:
-            pad_token_id = tokenizer.eos_token_id
+        if generation_config.pad_token_id is None:
+            pad_token_id = generation_config.eos_token_id
         else:
-            pad_token_id = tokenizer.pad_token_id
+            pad_token_id = generation_config.pad_token_id
 
         all_eos_token_ids = []
 
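Switching parse_eos_tokens from tokenizer attributes to the generation config matters for chat checkpoints, which often declare their end-of-turn token(s) only in generation_config.json. A toy run of the new fallback (the token id is made up):

from transformers import GenerationConfig

generation_config = GenerationConfig(eos_token_id=128009, pad_token_id=None)

# Same fallback as the patched helper: reuse EOS as PAD when PAD is unset
if generation_config.pad_token_id is None:
    pad_token_id = generation_config.eos_token_id
else:
    pad_token_id = generation_config.pad_token_id
print(pad_token_id)  # 128009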
@@ -526,7 +541,7 @@ class ChatCommand(BaseTransformersCLICommand):
             all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
 
         if len(all_eos_token_ids) == 0:
-            all_eos_token_ids.append(tokenizer.eos_token_id)
+            all_eos_token_ids.append(generation_config.eos_token_id)
 
         return pad_token_id, all_eos_token_ids
 
@@ -683,7 +698,7 @@ class ChatCommand(BaseTransformersCLICommand):
 
         model, tokenizer = self.load_model_and_tokenizer(args)
         generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
-        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
+        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
 
         # if not verbose -> disable warnings, progress bars, etc in the chat interface
         if not args.verbose:
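Note that the checkpoint-derived base only applies on the `if args.generation_config is None:` branch, i.e. when the user passes no explicit generation config; the other branch is not shown in this diff. To inspect which base values a given checkpoint will now contribute (the model name is illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
print(model.generation_config)  # the values the chat CLI now starts from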