Correctly list the chat template file in the Tokenizer saved files list (#34974)

* Correctly list the chat template file in the saved files list

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add save file checking to test

* make fixup

* better filename handling

* make fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Matt 2025-01-07 19:11:02 +00:00 committed by GitHub
parent cdca3cf9e3
commit a7d1441d65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 2 deletions

View File

@ -2429,6 +2429,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
tokenizer_config.update(self.extra_special_tokens)
saved_raw_chat_template = False
if self.chat_template is not None:
if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
@ -2439,6 +2440,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
elif kwargs.get("save_raw_chat_template", False):
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
saved_raw_chat_template = True
logger.info(f"chat template saved in {chat_template_file}")
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
@ -2498,6 +2500,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
file_names = (tokenizer_config_file, special_tokens_map_file)
if saved_raw_chat_template:
file_names += (chat_template_file,)
save_files = self._save_pretrained(
save_directory=save_directory,

View File

@ -1107,7 +1107,9 @@ class TokenizerTesterMixin:
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
save_files = tokenizer.save_pretrained(tmp_dir_name)
# Check we aren't saving a chat_template.jinja file
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
@ -1117,7 +1119,9 @@ class TokenizerTesterMixin:
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
# Check we are saving a chat_template.jinja file
self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files))
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
self.assertTrue(chat_template_file.is_file())
self.assertEqual(chat_template_file.read_text(), dummy_template)