Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Improved keras imports (#24448)
* An end to accursed version-specific imports
* No more K.is_keras_tensor() either
* Update dependency tables
* Use a cleaner call context function getter
* Add a cap to <2.14
* Add cap to examples requirements too
parent 1e9da2b0a6
commit 8e164c5400
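In short, the PR tightens the TensorFlow pin to >=2.6,<2.14 so the Keras backend can be imported the same way on every supported version, with no branching on tf.__version__. A minimal sketch of the resulting import pattern, assuming a TF install inside that pin (the print lines are illustrative only):

import tensorflow as tf
from keras import backend as K  # works on every pinned TF version, no version check needed

print(tf.__version__)
print(K.floatx())  # on these versions the standalone keras package is the same backend tf.keras re-exports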
examples/tensorflow/_tests_requirements.txt
@@ -1,4 +1,4 @@
-tensorflow<2.13
+tensorflow<2.14
 tensorboard
 scikit-learn
 seqeval
setup.py
@@ -168,9 +168,9 @@ _deps = [
     "sudachipy>=0.6.6",
     "sudachidict_core>=20220729",
     # TensorFlow pin. When changing this value, update examples/tensorflow/_tests_requirements.txt accordingly
-    "tensorflow-cpu>=2.4,<2.13",
-    "tensorflow>=2.4,<2.13",
-    "tensorflow-text<2.13",
+    "tensorflow-cpu>=2.6,<2.14",
+    "tensorflow>=2.6,<2.14",
+    "tensorflow-text<2.14",
     "tf2onnx",
     "timeout-decorator",
     "timm",
src/transformers/dependency_versions_table.py
@@ -72,9 +72,9 @@ deps = {
     "starlette": "starlette",
     "sudachipy": "sudachipy>=0.6.6",
     "sudachidict_core": "sudachidict_core>=20220729",
-    "tensorflow-cpu": "tensorflow-cpu>=2.4,<2.13",
-    "tensorflow": "tensorflow>=2.4,<2.13",
-    "tensorflow-text": "tensorflow-text<2.13",
+    "tensorflow-cpu": "tensorflow-cpu>=2.6,<2.14",
+    "tensorflow": "tensorflow>=2.6,<2.14",
+    "tensorflow-text": "tensorflow-text<2.14",
     "tf2onnx": "tf2onnx",
     "timeout-decorator": "timeout-decorator",
     "timm": "timm",
src/transformers/modeling_tf_pytorch_utils.py
@@ -251,12 +251,7 @@ def load_pytorch_state_dict_in_tf2_model(
     """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
     safetensors archive created with the safe_open() function."""
     import tensorflow as tf
-    from packaging.version import parse
-
-    if parse(tf.__version__) >= parse("2.11.0"):
-        from keras import backend as K
-    else:
-        from tensorflow.python.keras import backend as K
+    from keras import backend as K
 
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs
src/transformers/modeling_tf_utils.py
@@ -33,7 +33,9 @@ import h5py
 import numpy as np
 import tensorflow as tf
 from huggingface_hub import Repository, list_repo_files
+from keras import backend as K
 from packaging.version import parse
+from tensorflow.python.util.keras_deps import get_call_context_function
 
 from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
@@ -71,20 +73,6 @@ from .utils import (
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 
 
-if parse(tf.__version__).minor >= 13:
-    from keras import backend as K
-    from keras.__internal__ import KerasTensor
-    from keras.src.engine.base_layer_utils import call_context
-elif parse(tf.__version__).minor >= 11:
-    from keras import backend as K
-    from keras.engine.base_layer_utils import call_context
-    from keras.engine.keras_tensor import KerasTensor
-else:
-    from tensorflow.python.keras import backend as K
-    from tensorflow.python.keras.engine.base_layer_utils import call_context
-    from tensorflow.python.keras.engine.keras_tensor import KerasTensor
-
-
 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.tensorflow import save_file as safe_save_file
@@ -99,13 +87,10 @@ tf_logger = tf.get_logger()
 TFModelInputType = Union[
     List[tf.Tensor],
     List[np.ndarray],
-    List[KerasTensor],
     Dict[str, tf.Tensor],
     Dict[str, np.ndarray],
-    Dict[str, KerasTensor],
     tf.Tensor,
     np.ndarray,
-    KerasTensor,
 ]
 
 
@@ -472,7 +457,7 @@ def input_processing(func, config, **kwargs):
     main_input_name = parameter_names[0]
     main_input = kwargs.pop(main_input_name, None)
     output = {}
-    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
+    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
 
     if "inputs" in kwargs["kwargs_call"]:
         warnings.warn(
@@ -511,7 +496,7 @@ def input_processing(func, config, **kwargs):
     kwargs.pop("kwargs_call")
 
     for k, v in kwargs.items():
-        if isinstance(v, allowed_types) or v is None:
+        if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
             output[k] = v
         else:
             raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
@@ -564,7 +549,7 @@ def input_processing(func, config, **kwargs):
             else:
                 raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
     else:
-        if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
+        if tf.is_tensor(main_input) or main_input is None:
             output[main_input_name] = main_input
         else:
             raise ValueError(
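The two hunks above drop the explicit KerasTensor isinstance checks in favor of tf.is_tensor(), which the commit relies on to also cover symbolic tensors on the pinned TF versions. A small illustrative check of the eager side only (the values passed in are arbitrary examples):

import numpy as np
import tensorflow as tf

# tf.is_tensor() is True for TF-native tensor types and False for plain
# Python/NumPy values, which still fall through to the allowed_types check.
print(tf.is_tensor(tf.constant([1.0, 2.0])))  # True
print(tf.is_tensor(np.array([1.0, 2.0])))     # False
print(tf.is_tensor(3))                        # False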
@@ -1142,6 +1127,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         return "tf"
 
     def build(self, input_shape=None):
+        call_context = get_call_context_function()
         if self.built or call_context().in_call:
             self.built = True
         else:
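For context, a hedged sketch of the call-context pattern added to build() above, exercised in a standalone layer; TinyLayer and its weight are hypothetical stand-ins, only get_call_context_function and the .in_call check come from the diff:

import tensorflow as tf
from tensorflow.python.util.keras_deps import get_call_context_function


class TinyLayer(tf.keras.layers.Layer):
    def build(self, input_shape=None):
        # Fetch the Keras call-context function lazily, then skip re-building
        # when this build() is triggered from inside another layer's call().
        call_context = get_call_context_function()
        if self.built or call_context().in_call:
            self.built = True
            return
        self.w = self.add_weight(name="w", shape=(1,), initializer="zeros")
        super().build(input_shape)


layer = TinyLayer()
layer.build()  # not inside a call(), so the weight is created eagerly
print(layer.built, layer.w.shape)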