Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Reimplement "Automatic safetensors conversion when lacking these files" (#29846)
* Automatic safetensors conversion when lacking these files (#29390)
* Automatic safetensors conversion when lacking these files
* Remove debug
* Thread name
* Typo
* Ensure that raises do not affect the main thread
* Catch all errors
commit 4d8427f739
parent a81cf9ee90
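In short: when `from_pretrained` resolves only PyTorch weights, it now checks whether the repository already has a safetensors variant and, if not, spawns a background thread that asks the Hub's conversion bot to open one as a PR. A minimal usage sketch, assuming a hypothetical repo id that only ships `pytorch_model.bin`:

```python
from transformers import BertModel

# "some-org/pytorch-only-bert" is an illustrative repo id, not a real checkpoint.
# The call returns immediately with the PyTorch weights; a thread named
# "Thread-autoconversion" requests a safetensors PR in the background, and any
# error it hits is swallowed rather than propagated here.
model = BertModel.from_pretrained("some-org/pytorch-only-bert")
```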
src/transformers/modeling_utils.py

```diff
@@ -29,6 +29,7 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
+from threading import Thread
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from zipfile import is_zipfile
```
```diff
@@ -3228,9 +3229,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         )
                         if resolved_archive_file is not None:
                             is_sharded = True
-                if resolved_archive_file is None:
-                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
-                    # message.
+
+                if resolved_archive_file is not None:
+                    if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
+                        # If the PyTorch file was found, check if there is a safetensors file on the repository
+                        # If there is no safetensors file on the repositories, start an auto conversion
+                        safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                        }
+                        cached_file_kwargs = {
+                            "cache_dir": cache_dir,
+                            "force_download": force_download,
+                            "resume_download": resume_download,
+                            "local_files_only": local_files_only,
+                            "user_agent": user_agent,
+                            "subfolder": subfolder,
+                            "_raise_exceptions_for_gated_repo": False,
+                            "_raise_exceptions_for_missing_entries": False,
+                            "_commit_hash": commit_hash,
+                            **has_file_kwargs,
+                        }
+                        if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
+                            Thread(
+                                target=auto_conversion,
+                                args=(pretrained_model_name_or_path,),
+                                kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                name="Thread-autoconversion",
+                            ).start()
+                else:
+                    # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
+                    # We try those to give a helpful error message.
                     has_file_kwargs = {
                         "revision": revision,
                         "proxies": proxies,
```
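The worker is deliberately given a stable name, "Thread-autoconversion", so other code can find and wait on it without holding a reference; the tests added below rely on exactly this. A minimal sketch of that lookup:

```python
import threading

# Find the background conversion worker by name and give it up to ten seconds
# to finish; if no such thread is running, the loop is simply a no-op.
for thread in threading.enumerate():
    if thread.name == "Thread-autoconversion":
        thread.join(timeout=10)
```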
src/transformers/safetensors_conversion.py

```diff
@@ -84,24 +84,28 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
     return sha
 
 
-def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
-    api = HfApi(token=cached_file_kwargs.get("token"))
-    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
+def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
+    try:
+        api = HfApi(token=cached_file_kwargs.get("token"))
+        sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
 
-    if sha is None:
-        return None, None
-    cached_file_kwargs["revision"] = sha
-    del cached_file_kwargs["_commit_hash"]
+        if sha is None:
+            return None, None
+        cached_file_kwargs["revision"] = sha
+        del cached_file_kwargs["_commit_hash"]
 
-    # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
-    # description.
-    sharded = api.file_exists(
-        pretrained_model_name_or_path,
-        "model.safetensors.index.json",
-        revision=sha,
-        token=cached_file_kwargs.get("token"),
-    )
-    filename = "model.safetensors.index.json" if sharded else "model.safetensors"
+        # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
+        # description.
+        sharded = api.file_exists(
+            pretrained_model_name_or_path,
+            "model.safetensors.index.json",
+            revision=sha,
+            token=cached_file_kwargs.get("token"),
+        )
+        filename = "model.safetensors.index.json" if sharded else "model.safetensors"
 
-    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
-    return resolved_archive_file, sha, sharded
+        resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+        return resolved_archive_file, sha, sharded
+    except Exception as e:
+        if not ignore_errors_during_conversion:
+            raise e
```
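One flag now serves both call sites: the synchronous path keeps its exceptions, while the threaded path passes `ignore_errors_during_conversion=True` so a Hub hiccup dies quietly inside the worker. A minimal sketch of the pattern, with `open_conversion_pr` as a hypothetical stand-in for the real Hub calls:

```python
def open_conversion_pr(repo_id: str, **kwargs):
    # Hypothetical stand-in for the real Hub calls; always fails here to
    # demonstrate the two error-handling modes.
    raise ConnectionError(f"could not reach the Hub for {repo_id}")


def convert(repo_id: str, ignore_errors_during_conversion: bool = False, **kwargs):
    try:
        return open_conversion_pr(repo_id, **kwargs)
    except Exception as e:
        if not ignore_errors_during_conversion:
            raise e  # synchronous callers still see the failure
        # threaded callers fall through: the function returns None quietly


print(convert("org/model", ignore_errors_during_conversion=True))  # prints: None
```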
tests/test_modeling_utils.py

```diff
@@ -20,6 +20,7 @@ import os
 import os.path
 import sys
 import tempfile
+import threading
 import unittest
 import unittest.mock as mock
 import uuid
```
```diff
@@ -1428,7 +1429,7 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
         bot_opened_pr_title = None
 
         for discussion in discussions:
-            if discussion.author == "SFconvertBot":
+            if discussion.author == "SFconvertbot":
                 bot_opened_pr = True
                 bot_opened_pr_title = discussion.title
 
```
```diff
@@ -1451,6 +1452,51 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
         with self.assertRaises(EnvironmentError):
             BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
 
+    def test_absence_of_safetensors_triggers_conversion(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # Download the model that doesn't have safetensors
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+
+        for thread in threading.enumerate():
+            if thread.name == "Thread-autoconversion":
+                thread.join(timeout=10)
+
+        with self.subTest("PR was open with the safetensors account"):
+            discussions = self.api.get_repo_discussions(self.repo_name)
+
+            bot_opened_pr = None
+            bot_opened_pr_title = None
+
+            for discussion in discussions:
+                if discussion.author == "SFconvertbot":
+                    bot_opened_pr = True
+                    bot_opened_pr_title = discussion.title
+
+            self.assertTrue(bot_opened_pr)
+            self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
+
+    @mock.patch("transformers.safetensors_conversion.spawn_conversion")
+    def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):
+        spawn_conversion_mock.side_effect = HTTPError()
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+
 
 @require_torch
 @is_staging_test
```
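The second test injects a failure by patching `transformers.safetensors_conversion.spawn_conversion` to raise, then checks that the user-facing load still succeeds. A hedged sketch of that failure-injection pattern on its own:

```python
import unittest.mock as mock

from requests.exceptions import HTTPError

# Force the lowest-level conversion entry point to fail; a from_pretrained
# call made inside this context must still return normally, because the
# HTTPError is confined to the conversion worker thread.
with mock.patch("transformers.safetensors_conversion.spawn_conversion", side_effect=HTTPError()):
    ...  # e.g. BertModel.from_pretrained(repo_name, token=token)
```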