Improved keras imports (#24448)

* An end to accursed version-specific imports

* No more K.is_keras_tensor() either

* Update dependency tables

* Use a cleaner call context function getter

* Add a cap to <2.14

* Add cap to examples requirements too
This commit is contained in:
Matt 2023-06-23 19:09:34 +01:00 committed by GitHub
parent 1e9da2b0a6
commit 8e164c5400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 14 additions and 33 deletions

View File

@ -1,4 +1,4 @@
tensorflow<2.13
tensorflow<2.14
tensorboard
scikit-learn
seqeval

View File

@ -168,9 +168,9 @@ _deps = [
"sudachipy>=0.6.6",
"sudachidict_core>=20220729",
# TensorFlow pin. When changing this value, update examples/tensorflow/_tests_requirements.txt accordingly
"tensorflow-cpu>=2.4,<2.13",
"tensorflow>=2.4,<2.13",
"tensorflow-text<2.13",
"tensorflow-cpu>=2.6,<2.14",
"tensorflow>=2.6,<2.14",
"tensorflow-text<2.14",
"tf2onnx",
"timeout-decorator",
"timm",

View File

@ -72,9 +72,9 @@ deps = {
"starlette": "starlette",
"sudachipy": "sudachipy>=0.6.6",
"sudachidict_core": "sudachidict_core>=20220729",
"tensorflow-cpu": "tensorflow-cpu>=2.4,<2.13",
"tensorflow": "tensorflow>=2.4,<2.13",
"tensorflow-text": "tensorflow-text<2.13",
"tensorflow-cpu": "tensorflow-cpu>=2.6,<2.14",
"tensorflow": "tensorflow>=2.6,<2.14",
"tensorflow-text": "tensorflow-text<2.14",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"timm": "timm",

View File

@ -251,12 +251,7 @@ def load_pytorch_state_dict_in_tf2_model(
"""Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
safetensors archive created with the safe_open() function."""
import tensorflow as tf
from packaging.version import parse
if parse(tf.__version__) >= parse("2.11.0"):
from keras import backend as K
else:
from tensorflow.python.keras import backend as K
from keras import backend as K
if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs

View File

@ -33,7 +33,9 @@ import h5py
import numpy as np
import tensorflow as tf
from huggingface_hub import Repository, list_repo_files
from keras import backend as K
from packaging.version import parse
from tensorflow.python.util.keras_deps import get_call_context_function
from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
@ -71,20 +73,6 @@ from .utils import (
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
if parse(tf.__version__).minor >= 13:
from keras import backend as K
from keras.__internal__ import KerasTensor
from keras.src.engine.base_layer_utils import call_context
elif parse(tf.__version__).minor >= 11:
from keras import backend as K
from keras.engine.base_layer_utils import call_context
from keras.engine.keras_tensor import KerasTensor
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer_utils import call_context
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
if is_safetensors_available():
from safetensors import safe_open
from safetensors.tensorflow import save_file as safe_save_file
@ -99,13 +87,10 @@ tf_logger = tf.get_logger()
TFModelInputType = Union[
List[tf.Tensor],
List[np.ndarray],
List[KerasTensor],
Dict[str, tf.Tensor],
Dict[str, np.ndarray],
Dict[str, KerasTensor],
tf.Tensor,
np.ndarray,
KerasTensor,
]
@ -472,7 +457,7 @@ def input_processing(func, config, **kwargs):
main_input_name = parameter_names[0]
main_input = kwargs.pop(main_input_name, None)
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(
@ -511,7 +496,7 @@ def input_processing(func, config, **kwargs):
kwargs.pop("kwargs_call")
for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
@ -564,7 +549,7 @@ def input_processing(func, config, **kwargs):
else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else:
if isinstance(main_input, (tf.Tensor, KerasTensor)) or main_input is None:
if tf.is_tensor(main_input) or main_input is None:
output[main_input_name] = main_input
else:
raise ValueError(
@ -1142,6 +1127,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return "tf"
def build(self, input_shape=None):
call_context = get_call_context_function()
if self.built or call_context().in_call:
self.built = True
else: