[2/N] Use pyupgrade --py39-plus to improve code (#36857)

Use pyupgrade --py39-plus to improve code
This commit is contained in:
cyyever 2025-03-24 23:42:25 +08:00 committed by GitHub
parent a6ecb54159
commit 00d077267a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 116 additions and 129 deletions

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# #
@ -20,7 +19,7 @@ import json
import os import os
import re import re
import warnings import warnings
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
from packaging import version from packaging import version
@ -196,11 +195,11 @@ class PretrainedConfig(PushToHubMixin):
model_type: str = "" model_type: str = ""
base_config_key: str = "" base_config_key: str = ""
sub_configs: Dict[str, "PretrainedConfig"] = {} sub_configs: dict[str, "PretrainedConfig"] = {}
is_composition: bool = False is_composition: bool = False
attribute_map: Dict[str, str] = {} attribute_map: dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None base_model_tp_plan: Optional[dict[str, Any]] = None
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
_auto_class: Optional[str] = None _auto_class: Optional[str] = None
def __setattr__(self, key, value): def __setattr__(self, key, value):
@ -574,7 +573,7 @@ class PretrainedConfig(PushToHubMixin):
@classmethod @classmethod
def get_config_dict( def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
[`PretrainedConfig`] using `from_dict`. [`PretrainedConfig`] using `from_dict`.
@ -609,7 +608,7 @@ class PretrainedConfig(PushToHubMixin):
@classmethod @classmethod
def _get_config_dict( def _get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None) resume_download = kwargs.pop("resume_download", None)
@ -667,13 +666,13 @@ class PretrainedConfig(PushToHubMixin):
if resolved_config_file is None: if resolved_config_file is None:
return None, kwargs return None, kwargs
commit_hash = extract_commit_hash(resolved_config_file, commit_hash) commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError: except OSError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception. # the original exception.
raise raise
except Exception: except Exception:
# For any other exception, we throw a generic error. # For any other exception, we throw a generic error.
raise EnvironmentError( raise OSError(
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it" f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same" " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory" f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
@ -689,9 +688,7 @@ class PretrainedConfig(PushToHubMixin):
config_dict["_commit_hash"] = commit_hash config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError( raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)
if is_local: if is_local:
logger.info(f"loading configuration file {resolved_config_file}") logger.info(f"loading configuration file {resolved_config_file}")
@ -714,7 +711,7 @@ class PretrainedConfig(PushToHubMixin):
return config_dict, kwargs return config_dict, kwargs
@classmethod @classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "PretrainedConfig":
""" """
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters. Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
@ -792,7 +789,7 @@ class PretrainedConfig(PushToHubMixin):
@classmethod @classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
return json.loads(text) return json.loads(text)
@ -803,10 +800,9 @@ class PretrainedConfig(PushToHubMixin):
return f"{self.__class__.__name__} {self.to_json_string()}" return f"{self.__class__.__name__} {self.to_json_string()}"
def __iter__(self): def __iter__(self):
for attr in self.__dict__: yield from self.__dict__
yield attr
def to_diff_dict(self) -> Dict[str, Any]: def to_diff_dict(self) -> dict[str, Any]:
""" """
Removes all attributes from config which correspond to the default config attributes for better readability and Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary. serializes to a Python dictionary.
@ -874,7 +870,7 @@ class PretrainedConfig(PushToHubMixin):
return serializable_config_dict return serializable_config_dict
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
""" """
Serializes this instance to a Python dictionary. Serializes this instance to a Python dictionary.
@ -954,7 +950,7 @@ class PretrainedConfig(PushToHubMixin):
with open(json_file_path, "w", encoding="utf-8") as writer: with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff)) writer.write(self.to_json_string(use_diff=use_diff))
def update(self, config_dict: Dict[str, Any]): def update(self, config_dict: dict[str, Any]):
""" """
Updates attributes of this class with attributes from `config_dict`. Updates attributes of this class with attributes from `config_dict`.
@ -1002,7 +998,7 @@ class PretrainedConfig(PushToHubMixin):
setattr(self, k, v) setattr(self, k, v)
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: def dict_torch_dtype_to_str(self, d: dict[str, Any]) -> None:
""" """
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
@ -1044,7 +1040,7 @@ class PretrainedConfig(PushToHubMixin):
cls._auto_class = auto_class cls._auto_class = auto_class
@staticmethod @staticmethod
def _get_global_generation_defaults() -> Dict[str, Any]: def _get_global_generation_defaults() -> dict[str, Any]:
return { return {
"max_length": 20, "max_length": 20,
"min_length": 0, "min_length": 0,
@ -1073,7 +1069,7 @@ class PretrainedConfig(PushToHubMixin):
"begin_suppress_tokens": None, "begin_suppress_tokens": None,
} }
def _get_non_default_generation_parameters(self) -> Dict[str, Any]: def _get_non_default_generation_parameters(self) -> dict[str, Any]:
""" """
Gets the non-default generation parameters on the PretrainedConfig instance Gets the non-default generation parameters on the PretrainedConfig instance
""" """
@ -1148,7 +1144,7 @@ class PretrainedConfig(PushToHubMixin):
return self return self
def get_configuration_file(configuration_files: List[str]) -> str: def get_configuration_file(configuration_files: list[str]) -> str:
""" """
Get the configuration file to use for this version of transformers. Get the configuration file to use for this version of transformers.

View File

@ -1587,7 +1587,7 @@ class TikTokenConverter:
from tiktoken.load import load_tiktoken_bpe from tiktoken.load import load_tiktoken_bpe
except Exception: except Exception:
raise ValueError( raise ValueError(
"`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`." "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
) )
bpe_ranks = load_tiktoken_bpe(tiktoken_url) bpe_ranks = load_tiktoken_bpe(tiktoken_url)

View File

@ -206,7 +206,7 @@ class DebugUnderflowOverflow:
self.expand_frame(f"{'abs min':8} {'abs max':8} metadata") self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
def batch_end_frame(self): def batch_end_frame(self):
self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n") self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n")
def create_frame(self, module, input, output): def create_frame(self, module, input, output):
self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}") self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. # Copyright 2021 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -24,11 +23,10 @@ import shutil
import signal import signal
import sys import sys
import threading import threading
import typing
import warnings import warnings
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from huggingface_hub import try_to_load_from_cache from huggingface_hub import try_to_load_from_cache
@ -84,7 +82,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
importlib.invalidate_caches() importlib.invalidate_caches()
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]: def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
""" """
Get the list of modules that are relatively imported in a module file. Get the list of modules that are relatively imported in a module file.
@ -94,7 +92,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
Returns: Returns:
`List[str]`: The list of relative imports in the module. `List[str]`: The list of relative imports in the module.
""" """
with open(module_file, "r", encoding="utf-8") as f: with open(module_file, encoding="utf-8") as f:
content = f.read() content = f.read()
# Imports of the form `import .xxx` # Imports of the form `import .xxx`
@ -105,7 +103,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
return list(set(relative_imports)) return list(set(relative_imports))
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]: def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]:
""" """
Get the list of all files that are needed for a given module. Note that this function recurses through the relative Get the list of all files that are needed for a given module. Note that this function recurses through the relative
imports (if a imports b and b imports c, it will return module files for b and c). imports (if a imports b and b imports c, it will return module files for b and c).
@ -138,7 +136,7 @@ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]
return all_relative_imports return all_relative_imports
def get_imports(filename: Union[str, os.PathLike]) -> List[str]: def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
""" """
Extracts all the libraries (not relative imports this time) that are imported in a file. Extracts all the libraries (not relative imports this time) that are imported in a file.
@ -148,7 +146,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
Returns: Returns:
`List[str]`: The list of all packages required to use the input module. `List[str]`: The list of all packages required to use the input module.
""" """
with open(filename, "r", encoding="utf-8") as f: with open(filename, encoding="utf-8") as f:
content = f.read() content = f.read()
# filter out try/except block so in custom code we can have try/except imports # filter out try/except block so in custom code we can have try/except imports
@ -168,7 +166,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
return list(set(imports)) return list(set(imports))
def check_imports(filename: Union[str, os.PathLike]) -> List[str]: def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
""" """
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
library is missing. library is missing.
@ -208,7 +206,7 @@ def get_class_in_module(
module_path: Union[str, os.PathLike], module_path: Union[str, os.PathLike],
*, *,
force_reload: bool = False, force_reload: bool = False,
) -> typing.Type: ) -> type:
""" """
Import a module on the cache directory for modules and extract a class from it. Import a module on the cache directory for modules and extract a class from it.
@ -235,7 +233,7 @@ def get_class_in_module(
module_spec = importlib.util.spec_from_file_location(name, location=module_file) module_spec = importlib.util.spec_from_file_location(name, location=module_file)
# Hash the module file and all its relative imports to check if we need to reload it # Hash the module file and all its relative imports to check if we need to reload it
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType module: ModuleType
@ -258,7 +256,7 @@ def get_cached_module_file(
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: Optional[bool] = None, resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
@ -358,7 +356,7 @@ def get_cached_module_file(
if not is_local and cached_module != resolved_module_file: if not is_local and cached_module != resolved_module_file:
new_files.append(module_file) new_files.append(module_file)
except EnvironmentError: except OSError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise raise
@ -434,14 +432,14 @@ def get_class_from_dynamic_module(
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: Optional[bool] = None, resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
repo_type: Optional[str] = None, repo_type: Optional[str] = None,
code_revision: Optional[str] = None, code_revision: Optional[str] = None,
**kwargs, **kwargs,
) -> typing.Type: ) -> type:
""" """
Extracts a class from a module file, present in the local folder or repository of a model. Extracts a class from a module file, present in the local folder or repository of a model.
@ -553,7 +551,7 @@ def get_class_from_dynamic_module(
return get_class_in_module(class_name, final_module, force_reload=force_download) return get_class_in_module(class_name, final_module, force_reload=force_download)
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]: def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[dict] = None) -> list[str]:
""" """
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
adds the proper fields in a config. adds the proper fields in a config.

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. # Copyright 2021 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -21,7 +20,7 @@ import json
import os import os
import warnings import warnings
from collections import UserDict from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np import numpy as np
@ -74,7 +73,7 @@ class BatchFeature(UserDict):
initialization. initialization.
""" """
def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
super().__init__(data) super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type) self.convert_to_tensors(tensor_type=tensor_type)
@ -450,7 +449,7 @@ class FeatureExtractionMixin(PushToHubMixin):
@classmethod @classmethod
def get_feature_extractor_dict( def get_feature_extractor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`. feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
@ -521,13 +520,13 @@ class FeatureExtractionMixin(PushToHubMixin):
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
) )
except EnvironmentError: except OSError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception. # the original exception.
raise raise
except Exception: except Exception:
# For any other exception, we throw a generic error. # For any other exception, we throw a generic error.
raise EnvironmentError( raise OSError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
@ -536,12 +535,12 @@ class FeatureExtractionMixin(PushToHubMixin):
try: try:
# Load feature_extractor dict # Load feature_extractor dict
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader: with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
feature_extractor_dict = json.loads(text) feature_extractor_dict = json.loads(text)
except json.JSONDecodeError: except json.JSONDecodeError:
raise EnvironmentError( raise OSError(
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file." f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
) )
@ -565,7 +564,7 @@ class FeatureExtractionMixin(PushToHubMixin):
return feature_extractor_dict, kwargs return feature_extractor_dict, kwargs
@classmethod @classmethod
def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor: def from_dict(cls, feature_extractor_dict: dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
""" """
Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
parameters. parameters.
@ -601,7 +600,7 @@ class FeatureExtractionMixin(PushToHubMixin):
else: else:
return feature_extractor return feature_extractor
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
""" """
Serializes this instance to a Python dictionary. Returns: Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
@ -628,7 +627,7 @@ class FeatureExtractionMixin(PushToHubMixin):
A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
object instantiated from that JSON file. object instantiated from that JSON file.
""" """
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
feature_extractor_dict = json.loads(text) feature_extractor_dict = json.loads(text)
return cls(**feature_extractor_dict) return cls(**feature_extractor_dict)

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. # Copyright 2020 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -19,7 +18,7 @@ import json
import os import os
import warnings import warnings
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union from typing import Any, Optional, TypeVar, Union
import numpy as np import numpy as np
import requests import requests
@ -98,7 +97,7 @@ class ImageProcessingMixin(PushToHubMixin):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls: Type[ImageProcessorType], cls: type[ImageProcessorType],
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
@ -274,7 +273,7 @@ class ImageProcessingMixin(PushToHubMixin):
@classmethod @classmethod
def get_image_processor_dict( def get_image_processor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`. image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
@ -351,13 +350,13 @@ class ImageProcessingMixin(PushToHubMixin):
revision=revision, revision=revision,
subfolder=subfolder, subfolder=subfolder,
) )
except EnvironmentError: except OSError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception. # the original exception.
raise raise
except Exception: except Exception:
# For any other exception, we throw a generic error. # For any other exception, we throw a generic error.
raise EnvironmentError( raise OSError(
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
@ -366,12 +365,12 @@ class ImageProcessingMixin(PushToHubMixin):
try: try:
# Load image_processor dict # Load image_processor dict
with open(resolved_image_processor_file, "r", encoding="utf-8") as reader: with open(resolved_image_processor_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
image_processor_dict = json.loads(text) image_processor_dict = json.loads(text)
except json.JSONDecodeError: except json.JSONDecodeError:
raise EnvironmentError( raise OSError(
f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file." f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
) )
@ -393,7 +392,7 @@ class ImageProcessingMixin(PushToHubMixin):
return image_processor_dict, kwargs return image_processor_dict, kwargs
@classmethod @classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
""" """
Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters. Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
@ -437,7 +436,7 @@ class ImageProcessingMixin(PushToHubMixin):
else: else:
return image_processor return image_processor
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
""" """
Serializes this instance to a Python dictionary. Serializes this instance to a Python dictionary.
@ -463,7 +462,7 @@ class ImageProcessingMixin(PushToHubMixin):
An image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object An image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
instantiated from that JSON file. instantiated from that JSON file.
""" """
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
image_processor_dict = json.loads(text) image_processor_dict = json.loads(text)
return cls(**image_processor_dict) return cls(**image_processor_dict)
@ -529,7 +528,7 @@ class ImageProcessingMixin(PushToHubMixin):
cls._auto_class = auto_class cls._auto_class = auto_class
def fetch_images(self, image_url_or_urls: Union[str, List[str]]): def fetch_images(self, image_url_or_urls: Union[str, list[str]]):
""" """
Convert a single or a list of urls into the corresponding `PIL.Image` objects. Convert a single or a list of urls into the corresponding `PIL.Image` objects.

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team. # Copyright 2018 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -20,7 +19,7 @@ import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
import requests import requests
import yaml import yaml
@ -196,7 +195,7 @@ class ModelCard:
# Load model card # Load model card
modelcard = cls.from_json_file(resolved_model_card_file) modelcard = cls.from_json_file(resolved_model_card_file)
except (EnvironmentError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
# We fall back on creating an empty model card # We fall back on creating an empty model card
modelcard = cls() modelcard = cls()
@ -223,7 +222,7 @@ class ModelCard:
@classmethod @classmethod
def from_json_file(cls, json_file): def from_json_file(cls, json_file):
"""Constructs a `ModelCard` from a json file of parameters.""" """Constructs a `ModelCard` from a json file of parameters."""
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
dict_obj = json.loads(text) dict_obj = json.loads(text)
return cls(**dict_obj) return cls(**dict_obj)
@ -357,18 +356,18 @@ def _get_mapping_values(mapping):
@dataclass @dataclass
class TrainingSummary: class TrainingSummary:
model_name: str model_name: str
language: Optional[Union[str, List[str]]] = None language: Optional[Union[str, list[str]]] = None
license: Optional[str] = None license: Optional[str] = None
tags: Optional[Union[str, List[str]]] = None tags: Optional[Union[str, list[str]]] = None
finetuned_from: Optional[str] = None finetuned_from: Optional[str] = None
tasks: Optional[Union[str, List[str]]] = None tasks: Optional[Union[str, list[str]]] = None
dataset: Optional[Union[str, List[str]]] = None dataset: Optional[Union[str, list[str]]] = None
dataset_tags: Optional[Union[str, List[str]]] = None dataset_tags: Optional[Union[str, list[str]]] = None
dataset_args: Optional[Union[str, List[str]]] = None dataset_args: Optional[Union[str, list[str]]] = None
dataset_metadata: Optional[Dict[str, Any]] = None dataset_metadata: Optional[dict[str, Any]] = None
eval_results: Optional[Dict[str, float]] = None eval_results: Optional[dict[str, float]] = None
eval_lines: Optional[List[str]] = None eval_lines: Optional[list[str]] = None
hyperparameters: Optional[Dict[str, Any]] = None hyperparameters: Optional[dict[str, Any]] = None
source: Optional[str] = "trainer" source: Optional[str] = "trainer"
def __post_init__(self): def __post_init__(self):

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. # Copyright 2022 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -24,7 +23,7 @@ import sys
import typing import typing
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union from typing import Any, Callable, Optional, TypedDict, Union
import numpy as np import numpy as np
import typing_extensions import typing_extensions
@ -123,9 +122,9 @@ class TextKwargs(TypedDict, total=False):
The side on which padding will be applied. The side on which padding will be applied.
""" """
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]
text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] text_pair_target: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
add_special_tokens: Optional[bool] add_special_tokens: Optional[bool]
padding: Union[bool, str, PaddingStrategy] padding: Union[bool, str, PaddingStrategy]
truncation: Union[bool, str, TruncationStrategy] truncation: Union[bool, str, TruncationStrategy]
@ -184,17 +183,17 @@ class ImagesKwargs(TypedDict, total=False):
""" """
do_resize: Optional[bool] do_resize: Optional[bool]
size: Optional[Dict[str, int]] size: Optional[dict[str, int]]
size_divisor: Optional[int] size_divisor: Optional[int]
crop_size: Optional[Dict[str, int]] crop_size: Optional[dict[str, int]]
resample: Optional[Union["PILImageResampling", int]] resample: Optional[Union["PILImageResampling", int]]
do_rescale: Optional[bool] do_rescale: Optional[bool]
rescale_factor: Optional[float] rescale_factor: Optional[float]
do_normalize: Optional[bool] do_normalize: Optional[bool]
image_mean: Optional[Union[float, List[float]]] image_mean: Optional[Union[float, list[float]]]
image_std: Optional[Union[float, List[float]]] image_std: Optional[Union[float, list[float]]]
do_pad: Optional[bool] do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]] pad_size: Optional[dict[str, int]]
do_center_crop: Optional[bool] do_center_crop: Optional[bool]
data_format: Optional[ChannelDimension] data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]] input_data_format: Optional[Union[str, ChannelDimension]]
@ -235,14 +234,14 @@ class VideosKwargs(TypedDict, total=False):
""" """
do_resize: Optional[bool] do_resize: Optional[bool]
size: Optional[Dict[str, int]] size: Optional[dict[str, int]]
size_divisor: Optional[int] size_divisor: Optional[int]
resample: Optional["PILImageResampling"] resample: Optional["PILImageResampling"]
do_rescale: Optional[bool] do_rescale: Optional[bool]
rescale_factor: Optional[float] rescale_factor: Optional[float]
do_normalize: Optional[bool] do_normalize: Optional[bool]
image_mean: Optional[Union[float, List[float]]] image_mean: Optional[Union[float, list[float]]]
image_std: Optional[Union[float, List[float]]] image_std: Optional[Union[float, list[float]]]
do_pad: Optional[bool] do_pad: Optional[bool]
do_center_crop: Optional[bool] do_center_crop: Optional[bool]
data_format: Optional[ChannelDimension] data_format: Optional[ChannelDimension]
@ -280,7 +279,7 @@ class AudioKwargs(TypedDict, total=False):
""" """
sampling_rate: Optional[int] sampling_rate: Optional[int]
raw_speech: Optional[Union["np.ndarray", List[float], List["np.ndarray"], List[List[float]]]] raw_speech: Optional[Union["np.ndarray", list[float], list["np.ndarray"], list[list[float]]]]
padding: Optional[Union[bool, str, PaddingStrategy]] padding: Optional[Union[bool, str, PaddingStrategy]]
max_length: Optional[int] max_length: Optional[int]
truncation: Optional[bool] truncation: Optional[bool]
@ -379,8 +378,8 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
This functionality is only available for chat templates that support it via the `{% generation %}` keyword. This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
""" """
tools: Optional[List[Dict]] = None tools: Optional[list[dict]] = None
documents: Optional[List[Dict[str, str]]] = None documents: Optional[list[dict[str, str]]] = None
add_generation_prompt: Optional[bool] = False add_generation_prompt: Optional[bool] = False
continue_final_message: Optional[bool] = False continue_final_message: Optional[bool] = False
return_assistant_tokens_mask: Optional[bool] = False return_assistant_tokens_mask: Optional[bool] = False
@ -435,12 +434,12 @@ class ProcessorMixin(PushToHubMixin):
attributes = ["feature_extractor", "tokenizer"] attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"] optional_attributes = ["chat_template"]
optional_call_args: List[str] = [] optional_call_args: list[str] = []
# Names need to be attr_class for attr in attributes # Names need to be attr_class for attr in attributes
feature_extractor_class = None feature_extractor_class = None
tokenizer_class = None tokenizer_class = None
_auto_class = None _auto_class = None
valid_kwargs: List[str] = [] valid_kwargs: list[str] = []
# args have to match the attributes class attribute # args have to match the attributes class attribute
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -481,7 +480,7 @@ class ProcessorMixin(PushToHubMixin):
setattr(self, attribute_name, arg) setattr(self, attribute_name, arg)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
""" """
Serializes this instance to a Python dictionary. Serializes this instance to a Python dictionary.
@ -659,7 +658,7 @@ class ProcessorMixin(PushToHubMixin):
@classmethod @classmethod
def get_processor_dict( def get_processor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
processor of type [`~processing_utils.ProcessorMixin`] using `from_args_and_dict`. processor of type [`~processing_utils.ProcessorMixin`] using `from_args_and_dict`.
@ -764,13 +763,13 @@ class ProcessorMixin(PushToHubMixin):
subfolder=subfolder, subfolder=subfolder,
_raise_exceptions_for_missing_entries=False, _raise_exceptions_for_missing_entries=False,
) )
except EnvironmentError: except OSError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception. # the original exception.
raise raise
except Exception: except Exception:
# For any other exception, we throw a generic error. # For any other exception, we throw a generic error.
raise EnvironmentError( raise OSError(
f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load" f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the" " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
@ -779,11 +778,11 @@ class ProcessorMixin(PushToHubMixin):
# Add chat template as kwarg before returning because most models don't have processor config # Add chat template as kwarg before returning because most models don't have processor config
if resolved_raw_chat_template_file is not None: if resolved_raw_chat_template_file is not None:
with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: with open(resolved_raw_chat_template_file, encoding="utf-8") as reader:
chat_template = reader.read() chat_template = reader.read()
kwargs["chat_template"] = chat_template kwargs["chat_template"] = chat_template
elif resolved_chat_template_file is not None: elif resolved_chat_template_file is not None:
with open(resolved_chat_template_file, "r", encoding="utf-8") as reader: with open(resolved_chat_template_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
chat_template = json.loads(text)["chat_template"] chat_template = json.loads(text)["chat_template"]
kwargs["chat_template"] = chat_template kwargs["chat_template"] = chat_template
@ -801,14 +800,12 @@ class ProcessorMixin(PushToHubMixin):
try: try:
# Load processor dict # Load processor dict
with open(resolved_processor_file, "r", encoding="utf-8") as reader: with open(resolved_processor_file, encoding="utf-8") as reader:
text = reader.read() text = reader.read()
processor_dict = json.loads(text) processor_dict = json.loads(text)
except json.JSONDecodeError: except json.JSONDecodeError:
raise EnvironmentError( raise OSError(f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file.")
f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
)
if is_local: if is_local:
logger.info(f"loading configuration file {resolved_processor_file}") logger.info(f"loading configuration file {resolved_processor_file}")
@ -837,7 +834,7 @@ class ProcessorMixin(PushToHubMixin):
return processor_dict, kwargs return processor_dict, kwargs
@classmethod @classmethod
def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs): def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
""" """
Instantiates a type of [`~processing_utils.ProcessorMixin`] from a Python dictionary of parameters. Instantiates a type of [`~processing_utils.ProcessorMixin`] from a Python dictionary of parameters.
@ -882,9 +879,9 @@ class ProcessorMixin(PushToHubMixin):
def _merge_kwargs( def _merge_kwargs(
self, self,
ModelProcessorKwargs: ProcessingKwargs, ModelProcessorKwargs: ProcessingKwargs,
tokenizer_init_kwargs: Optional[Dict] = None, tokenizer_init_kwargs: Optional[dict] = None,
**kwargs, **kwargs,
) -> Dict[str, Dict]: ) -> dict[str, dict]:
""" """
Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance. Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance.
The order of operations is as follows: The order of operations is as follows:
@ -1236,10 +1233,10 @@ class ProcessorMixin(PushToHubMixin):
def _process_messages_for_chat_template( def _process_messages_for_chat_template(
self, self,
conversation: List[List[Dict[str, str]]], conversation: list[list[dict[str, str]]],
batch_images: List[ImageInput], batch_images: list[ImageInput],
batch_videos: List[VideoInput], batch_videos: list[VideoInput],
batch_video_metadata: List[List[Dict[str, any]]], batch_video_metadata: list[list[dict[str, any]]],
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate], **chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
): ):
""" """
@ -1270,7 +1267,7 @@ class ProcessorMixin(PushToHubMixin):
def apply_chat_template( def apply_chat_template(
self, self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
**kwargs: Unpack[AllKwargsForChatTemplate], **kwargs: Unpack[AllKwargsForChatTemplate],
) -> str: ) -> str:

View File

@ -950,11 +950,11 @@ def metrics_format(self, metrics: dict[str, float]) -> dict[str, float]:
metrics_copy = metrics.copy() metrics_copy = metrics.copy()
for k, v in metrics_copy.items(): for k, v in metrics_copy.items():
if "_mem_" in k: if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB" metrics_copy[k] = f"{v >> 20}MB"
elif "_runtime" in k: elif "_runtime" in k:
metrics_copy[k] = _secs2timedelta(v) metrics_copy[k] = _secs2timedelta(v)
elif k == "total_flos": elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF" metrics_copy[k] = f"{int(v) >> 30}GF"
elif isinstance(metrics_copy[k], float): elif isinstance(metrics_copy[k], float):
metrics_copy[k] = round(v, 4) metrics_copy[k] = round(v, 4)

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import io
import json import json
import math import math
import os import os
@ -22,7 +21,7 @@ from dataclasses import asdict, dataclass, field, fields
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from huggingface_hub import get_full_repo_name from huggingface_hub import get_full_repo_name
from packaging import version from packaging import version
@ -1138,7 +1137,7 @@ class TrainingArguments:
) )
}, },
) )
debug: Union[str, List[DebugOption]] = field( debug: Union[str, list[DebugOption]] = field(
default="", default="",
metadata={ metadata={
"help": ( "help": (
@ -1198,7 +1197,7 @@ class TrainingArguments:
remove_unused_columns: Optional[bool] = field( remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
) )
label_names: Optional[List[str]] = field( label_names: Optional[list[str]] = field(
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
) )
load_best_model_at_end: Optional[bool] = field( load_best_model_at_end: Optional[bool] = field(
@ -1225,7 +1224,7 @@ class TrainingArguments:
) )
}, },
) )
fsdp: Optional[Union[List[FSDPOption], str]] = field( fsdp: Optional[Union[list[FSDPOption], str]] = field(
default="", default="",
metadata={ metadata={
"help": ( "help": (
@ -1318,7 +1317,7 @@ class TrainingArguments:
default="length", default="length",
metadata={"help": "Column name with precomputed lengths to use when grouping by length."}, metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
) )
report_to: Union[None, str, List[str]] = field( report_to: Union[None, str, list[str]] = field(
default=None, metadata={"help": "The list of integrations to report the results and logs to."} default=None, metadata={"help": "The list of integrations to report the results and logs to."}
) )
ddp_find_unused_parameters: Optional[bool] = field( ddp_find_unused_parameters: Optional[bool] = field(
@ -1406,7 +1405,7 @@ class TrainingArguments:
"help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead." "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead."
}, },
) )
include_for_metrics: List[str] = field( include_for_metrics: list[str] = field(
default_factory=list, default_factory=list,
metadata={ metadata={
"help": "List of strings to specify additional data to include in the `compute_metrics` function." "help": "List of strings to specify additional data to include in the `compute_metrics` function."
@ -1534,7 +1533,7 @@ class TrainingArguments:
}, },
) )
optim_target_modules: Union[None, str, List[str]] = field( optim_target_modules: Union[None, str, list[str]] = field(
default=None, default=None,
metadata={ metadata={
"help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment." "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
@ -1940,7 +1939,7 @@ class TrainingArguments:
if isinstance(self.fsdp_config, str): if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0: if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f: with open(self.fsdp_config, encoding="utf-8") as f:
self.fsdp_config = json.load(f) self.fsdp_config = json.load(f)
for k in list(self.fsdp_config.keys()): for k in list(self.fsdp_config.keys()):
if k.startswith("fsdp_"): if k.startswith("fsdp_"):
@ -2546,7 +2545,7 @@ class TrainingArguments:
) )
return warmup_steps return warmup_steps
def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: def _dict_torch_dtype_to_str(self, d: dict[str, Any]) -> None:
""" """
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
@ -2586,7 +2585,7 @@ class TrainingArguments:
""" """
return json.dumps(self.to_dict(), indent=2) return json.dumps(self.to_dict(), indent=2)
def to_sanitized_dict(self) -> Dict[str, Any]: def to_sanitized_dict(self) -> dict[str, Any]:
""" """
Sanitized serialization to use with TensorBoard's hparams Sanitized serialization to use with TensorBoard's hparams
""" """
@ -2829,7 +2828,7 @@ class TrainingArguments:
self, self,
strategy: Union[str, IntervalStrategy] = "steps", strategy: Union[str, IntervalStrategy] = "steps",
steps: int = 500, steps: int = 500,
report_to: Union[str, List[str]] = "none", report_to: Union[str, list[str]] = "none",
level: str = "passive", level: str = "passive",
first_step: bool = False, first_step: bool = False,
nan_inf_filter: bool = False, nan_inf_filter: bool = False,