mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-17 19:48:23 +06:00
Improved keras imports (#24448)
* An end to accursed version-specific imports * No more K.is_keras_tensor() either * Update dependency tables * Use a cleaner call context function getter * Add a cap to <2.14 * Add cap to examples requirements too
This commit is contained in:
parent
1e9da2b0a6
commit
8e164c5400
@ -1,4 +1,4 @@
|
|||||||
tensorflow<2.13
|
tensorflow<2.14
|
||||||
tensorboard
|
tensorboard
|
||||||
scikit-learn
|
scikit-learn
|
||||||
seqeval
|
seqeval
|
||||||
|
6
setup.py
6
setup.py
@ -168,9 +168,9 @@ _deps = [
|
|||||||
"sudachipy>=0.6.6",
|
"sudachipy>=0.6.6",
|
||||||
"sudachidict_core>=20220729",
|
"sudachidict_core>=20220729",
|
||||||
# TensorFlow pin. When changing this value, update examples/tensorflow/_tests_requirements.txt accordingly
|
# TensorFlow pin. When changing this value, update examples/tensorflow/_tests_requirements.txt accordingly
|
||||||
"tensorflow-cpu>=2.4,<2.13",
|
"tensorflow-cpu>=2.6,<2.14",
|
||||||
"tensorflow>=2.4,<2.13",
|
"tensorflow>=2.6,<2.14",
|
||||||
"tensorflow-text<2.13",
|
"tensorflow-text<2.14",
|
||||||
"tf2onnx",
|
"tf2onnx",
|
||||||
"timeout-decorator",
|
"timeout-decorator",
|
||||||
"timm",
|
"timm",
|
||||||
|
@ -72,9 +72,9 @@ deps = {
|
|||||||
"starlette": "starlette",
|
"starlette": "starlette",
|
||||||
"sudachipy": "sudachipy>=0.6.6",
|
"sudachipy": "sudachipy>=0.6.6",
|
||||||
"sudachidict_core": "sudachidict_core>=20220729",
|
"sudachidict_core": "sudachidict_core>=20220729",
|
||||||
"tensorflow-cpu": "tensorflow-cpu>=2.4,<2.13",
|
"tensorflow-cpu": "tensorflow-cpu>=2.6,<2.14",
|
||||||
"tensorflow": "tensorflow>=2.4,<2.13",
|
"tensorflow": "tensorflow>=2.6,<2.14",
|
||||||
"tensorflow-text": "tensorflow-text<2.13",
|
"tensorflow-text": "tensorflow-text<2.14",
|
||||||
"tf2onnx": "tf2onnx",
|
"tf2onnx": "tf2onnx",
|
||||||
"timeout-decorator": "timeout-decorator",
|
"timeout-decorator": "timeout-decorator",
|
||||||
"timm": "timm",
|
"timm": "timm",
|
||||||
|
@ -251,12 +251,7 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
"""Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
|
"""Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
|
||||||
safetensors archive created with the safe_open() function."""
|
safetensors archive created with the safe_open() function."""
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from packaging.version import parse
|
|
||||||
|
|
||||||
if parse(tf.__version__) >= parse("2.11.0"):
|
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
else:
|
|
||||||
from tensorflow.python.keras import backend as K
|
|
||||||
|
|
||||||
if tf_inputs is None:
|
if tf_inputs is None:
|
||||||
tf_inputs = tf_model.dummy_inputs
|
tf_inputs = tf_model.dummy_inputs
|
||||||
|
@ -33,7 +33,9 @@ import h5py
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from huggingface_hub import Repository, list_repo_files
|
from huggingface_hub import Repository, list_repo_files
|
||||||
|
from keras import backend as K
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
|
from tensorflow.python.util.keras_deps import get_call_context_function
|
||||||
|
|
||||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||||
from .activations_tf import get_tf_activation
|
from .activations_tf import get_tf_activation
|
||||||
@ -71,20 +73,6 @@ from .utils import (
|
|||||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
|
||||||
|
|
||||||
if parse(tf.__version__).minor >= 13:
|
|
||||||
from keras import backend as K
|
|
||||||
from keras.__internal__ import KerasTensor
|
|
||||||
from keras.src.engine.base_layer_utils import call_context
|
|
||||||
elif parse(tf.__version__).minor >= 11:
|
|
||||||
from keras import backend as K
|
|
||||||
from keras.engine.base_layer_utils import call_context
|
|
||||||
from keras.engine.keras_tensor import KerasTensor
|
|
||||||
else:
|
|
||||||
from tensorflow.python.keras import backend as K
|
|
||||||
from tensorflow.python.keras.engine.base_layer_utils import call_context
|
|
||||||
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.tensorflow import save_file as safe_save_file
|
from safetensors.tensorflow import save_file as safe_save_file
|
||||||
@ -99,13 +87,10 @@ tf_logger = tf.get_logger()
|
|||||||
TFModelInputType = Union[
|
TFModelInputType = Union[
|
||||||
List[tf.Tensor],
|
List[tf.Tensor],
|
||||||
List[np.ndarray],
|
List[np.ndarray],
|
||||||
List[KerasTensor],
|
|
||||||
Dict[str, tf.Tensor],
|
Dict[str, tf.Tensor],
|
||||||
Dict[str, np.ndarray],
|
Dict[str, np.ndarray],
|
||||||
Dict[str, KerasTensor],
|
|
||||||
tf.Tensor,
|
tf.Tensor,
|
||||||
np.ndarray,
|
np.ndarray,
|
||||||
KerasTensor,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -472,7 +457,7 @@ def input_processing(func, config, **kwargs):
|
|||||||
main_input_name = parameter_names[0]
|
main_input_name = parameter_names[0]
|
||||||
main_input = kwargs.pop(main_input_name, None)
|
main_input = kwargs.pop(main_input_name, None)
|
||||||
output = {}
|
output = {}
|
||||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
|
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||||
|
|
||||||
if "inputs" in kwargs["kwargs_call"]:
|
if "inputs" in kwargs["kwargs_call"]:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@ -511,7 +496,7 @@ def input_processing(func, config, **kwargs):
|
|||||||
kwargs.pop("kwargs_call")
|
kwargs.pop("kwargs_call")
|
||||||
|
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if isinstance(v, allowed_types) or v is None:
|
if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
|
||||||
output[k] = v
|
output[k] = v
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
@ -564,7 +549,7 @@ def input_processing(func, config, **kwargs):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||||
else:
|
else:
|
||||||
if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
|
if tf.is_tensor(main_input) or main_input is None:
|
||||||
output[main_input_name] = main_input
|
output[main_input_name] = main_input
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1142,6 +1127,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
return "tf"
|
return "tf"
|
||||||
|
|
||||||
def build(self, input_shape=None):
|
def build(self, input_shape=None):
|
||||||
|
call_context = get_call_context_function()
|
||||||
if self.built or call_context().in_call:
|
if self.built or call_context().in_call:
|
||||||
self.built = True
|
self.built = True
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user