mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Correctly list the chat template file in the Tokenizer saved files list (#34974)
* Correctly list the chat template file in the saved files list
* Update src/transformers/tokenization_utils_base.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Add save file checking to test
* make fixup
* better filename handling
* make fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
cdca3cf9e3
commit
a7d1441d65
@@ -2429,6 +2429,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
|
||||
tokenizer_config.update(self.extra_special_tokens)
|
||||
|
||||
saved_raw_chat_template = False
|
||||
if self.chat_template is not None:
|
||||
if isinstance(self.chat_template, dict):
|
||||
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
|
||||
@@ -2439,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
elif kwargs.get("save_raw_chat_template", False):
|
||||
with open(chat_template_file, "w", encoding="utf-8") as f:
|
||||
f.write(self.chat_template)
|
||||
saved_raw_chat_template = True
|
||||
logger.info(f"chat template saved in {chat_template_file}")
|
||||
if "chat_template" in tokenizer_config:
|
||||
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
|
||||
@@ -2498,6 +2500,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
|
||||
|
||||
file_names = (tokenizer_config_file, special_tokens_map_file)
|
||||
if saved_raw_chat_template:
|
||||
file_names += (chat_template_file,)
|
||||
|
||||
save_files = self._save_pretrained(
|
||||
save_directory=save_directory,
|
||||
|
@@ -1107,7 +1107,9 @@ class TokenizerTesterMixin:
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
save_files = tokenizer.save_pretrained(tmp_dir_name)
|
||||
# Check we aren't saving a chat_template.jinja file
|
||||
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
@@ -1117,7 +1119,9 @@ class TokenizerTesterMixin:
|
||||
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
|
||||
save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
|
||||
# Check we are saving a chat_template.jinja file
|
||||
self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files))
|
||||
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
|
||||
self.assertTrue(chat_template_file.is_file())
|
||||
self.assertEqual(chat_template_file.read_text(), dummy_template)
|
||||
|
Loading…
Reference in New Issue
Block a user