Add tests for legacy load by url and fix bugs (#19078)

This commit is contained in:
Sylvain Gugger 2022-09-16 17:20:02 -04:00 committed by GitHub
parent ae219532e3
commit ca485e562b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 62 additions and 5 deletions

View File

@ -680,7 +680,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
archive_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
filename = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME

View File

@ -2418,7 +2418,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
filename = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename

View File

@ -2005,7 +2005,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
filename = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename

View File

@ -1670,7 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
init_configuration = {}
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "

View File

@ -360,6 +360,12 @@ class ConfigTestUtils(unittest.TestCase):
# This check we did call the fake head request
mock_head.assert_called()
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
_ = BertConfig.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
)
class ConfigurationVersioningTest(unittest.TestCase):
def test_local_versioning(self):

View File

@ -182,6 +182,12 @@ class FeatureExtractorUtilTester(unittest.TestCase):
# This check we did call the fake head request
mock_head.assert_called()
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
_ = Wav2Vec2FeatureExtractor.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json"
)
@is_staging_test
class FeatureExtractorPushToHubTester(unittest.TestCase):

View File

@ -33,6 +33,7 @@ import numpy as np
import transformers
from huggingface_hub import HfFolder, delete_repo, set_access_token
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError
from transformers import (
AutoConfig,
@ -2949,6 +2950,26 @@ class ModelUtilsTest(TestCasePlus):
# This check we did call the fake head request
mock_head.assert_called()
def test_load_from_one_file(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f
)
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = BertModel.from_pretrained(tmp_file, config=config)
finally:
os.remove(tmp_file)
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = BertModel.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
)
@require_torch
@is_staging_test

View File

@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints
from datasets import Dataset
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase):
# This check we did call the fake head request
mock_head.assert_called()
def test_load_from_one_file(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(tmp_file, config=config)
finally:
os.remove(tmp_file)
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
)
# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel:

View File

@ -3891,15 +3891,20 @@ class TokenizerUtilTester(unittest.TestCase):
mock_head.assert_called()
def test_legacy_load_from_one_file(self):
# This test is for deprecated behavior and can be removed in v5
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f)
AlbertTokenizer.from_pretrained(tmp_file)
_ = AlbertTokenizer.from_pretrained(tmp_file)
finally:
os.remove(tmp_file)
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
_ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert-base-v1/resolve/main/spiece.model")
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):