mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
parent
1e6b546ea6
commit
cdfb018d03
@ -298,24 +298,6 @@ def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
|||||||
return first_tuple[1].device
|
return first_tuple[1].device
|
||||||
|
|
||||||
|
|
||||||
def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
|
||||||
"""
|
|
||||||
Returns the first parameter dtype (can be non-floating) or asserts if none were found.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return next(parameter.parameters()).dtype
|
|
||||||
except StopIteration:
|
|
||||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
|
||||||
|
|
||||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
|
||||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
|
||||||
return tuples
|
|
||||||
|
|
||||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
|
||||||
first_tuple = next(gen)
|
|
||||||
return first_tuple[1].dtype
|
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
||||||
"""
|
"""
|
||||||
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||||
@ -365,17 +347,6 @@ def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
|||||||
return last_dtype
|
return last_dtype
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict_float_dtype(state_dict):
|
|
||||||
"""
|
|
||||||
Returns the first found floating dtype in `state_dict` or asserts if none were found.
|
|
||||||
"""
|
|
||||||
for t in state_dict.values():
|
|
||||||
if t.is_floating_point():
|
|
||||||
return t.dtype
|
|
||||||
|
|
||||||
raise ValueError("couldn't find any floating point dtypes in state_dict")
|
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict_dtype(state_dict):
|
def get_state_dict_dtype(state_dict):
|
||||||
"""
|
"""
|
||||||
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
|
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
|
||||||
|
@ -527,78 +527,6 @@ def cached_files(
|
|||||||
return resolved_files
|
return resolved_files
|
||||||
|
|
||||||
|
|
||||||
# TODO cyril: Deprecated and should be removed in 4.51
|
|
||||||
def get_file_from_repo(
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path_or_repo (`str` or `os.PathLike`):
|
|
||||||
This can be either:
|
|
||||||
|
|
||||||
- a string, the *model id* of a model repo on huggingface.co.
|
|
||||||
- a path to a *directory* potentially containing the file.
|
|
||||||
filename (`str`):
|
|
||||||
The name of the file to locate in `path_or_repo`.
|
|
||||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
|
||||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
|
||||||
cache should not be used.
|
|
||||||
force_download (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
|
||||||
exist.
|
|
||||||
resume_download:
|
|
||||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
|
||||||
Will be removed in v5 of Transformers.
|
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
||||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
|
||||||
token (`str` or *bool*, *optional*):
|
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
|
||||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
|
||||||
identifier allowed by git.
|
|
||||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
||||||
If `True`, will only try to load the tokenizer configuration from local files.
|
|
||||||
subfolder (`str`, *optional*, defaults to `""`):
|
|
||||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
|
||||||
specify the folder name here.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
Passing `token=True` is required when you want to use a private model.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
|
|
||||||
file does not exist.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Download a tokenizer configuration from huggingface.co and cache.
|
|
||||||
tokenizer_config = get_file_from_repo("google-bert/bert-base-uncased", "tokenizer_config.json")
|
|
||||||
# This model does not have a tokenizer config so the result will be None.
|
|
||||||
tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
logger.warning(
|
|
||||||
"`get_file_from_repo` is deprecated and will be removed in version 4.51. Use `cached_file` instead."
|
|
||||||
)
|
|
||||||
return cached_file(
|
|
||||||
*args,
|
|
||||||
_raise_exceptions_for_gated_repo=False,
|
|
||||||
_raise_exceptions_for_missing_entries=False,
|
|
||||||
_raise_exceptions_for_connection_errors=False,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def download_url(url, proxies=None):
|
def download_url(url, proxies=None):
|
||||||
"""
|
"""
|
||||||
Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
|
Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
|
||||||
|
Loading…
Reference in New Issue
Block a user