[add-new-model-like] Robust search & proper outer '),' in tokenizer mapping (#38703)

* [add-new-model-like] Robust search & proper outer '),' in tokenizer mapping

* code-style: arrange the importation in add_new_model_like.py

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
alexzms 2025-06-10 20:25:12 +08:00 committed by GitHub
parent 8340e8746e
commit 8ff22e9d3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# 1. Standard library
import difflib
import json
import os
@ -28,7 +29,12 @@ import yaml
from ..models import auto as auto_module
from ..models.auto.configuration_auto import model_type_to_module_name
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
from ..utils import (
is_flax_available,
is_tf_available,
is_torch_available,
logging,
)
from . import BaseTransformersCLICommand
from .add_fast_image_processor import add_fast_image_processor
@ -1009,10 +1015,11 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
content = f.read()
pattern_tokenizer = re.compile(r"^\s*TOKENIZER_MAPPING_NAMES\s*=\s*OrderedDict\b")
lines = content.split("\n")
idx = 0
# First we get to the TOKENIZER_MAPPING_NAMES block.
while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
while not pattern_tokenizer.search(lines[idx]):
idx += 1
idx += 1
@ -1024,9 +1031,12 @@ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model
# Otherwise it takes several lines until we get to a "),"
else:
block = []
# should change to " )," instead of " ),"
while not lines[idx].startswith(" ),"):
block.append(lines[idx])
idx += 1
# if the lines[idx] does start with " )," we still need it in our block
block.append(lines[idx])
block = "\n".join(block)
idx += 1