Include Keras tensor in the allowed types (#14155)

* Include KerasTensor in allowed types

- This allows propagating symbolic tensors through TFBert models and layers' call(),
  which allows converting the subclass models to functional models.

* Style pass

Co-authored-by: Sergio Valcarcel Macua <sergiov@graphcore.ai>
Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
Sergio Valcarcel Macua 2021-10-26 15:08:59 +01:00 committed by GitHub
parent f5ed19f57d
commit 919a964b8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,6 +27,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
@ -52,7 +53,15 @@ logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()
TFModelInputType = Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
List[tf.Tensor],
List[np.ndarray],
List[KerasTensor],
Dict[str, tf.Tensor],
Dict[str, np.ndarray],
Dict[str, KerasTensor],
tf.Tensor,
np.ndarray,
KerasTensor,
]
@ -348,7 +357,7 @@ def input_processing(func, config, input_ids, **kwargs):
signature.pop("self", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(
@ -432,7 +441,7 @@ def input_processing(func, config, input_ids, **kwargs):
else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else:
if isinstance(input_ids, tf.Tensor) or input_ids is None:
if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
output[parameter_names[0]] = input_ids
else:
raise ValueError(