Fix tokenizer load from one file (#19073)

* Fix tokenizer load from one file

* Add a test

* Style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger 2022-09-16 16:11:47 -04:00 committed by GitHub
parent 773314ab80
commit 9017ba4ca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 0 deletions

View File

@ -1726,6 +1726,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
elif os.path.isfile(file_path):
resolved_vocab_files[file_id] = file_path
elif is_remote_url(file_path):
resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
else:

View File

@ -31,6 +31,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfFolder, delete_repo, set_access_token
from huggingface_hub.file_download import http_get
from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import (
@ -3889,6 +3890,16 @@ class TokenizerUtilTester(unittest.TestCase):
# This check we did call the fake head request
mock_head.assert_called()
def test_legacy_load_from_one_file(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f)
AlbertTokenizer.from_pretrained(tmp_file)
finally:
os.remove(tmp_file)
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):