mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Include Keras tensor in the allowed types (#14155)
* Include KerasTensor in allowed types - This allows propagating symbolic tensors through TFBert models and layers' call(), which allows converting the subclass models to functional models. * Style pass Co-authored-by: Sergio Valcarcel Macua <sergiov@graphcore.ai> Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
parent
f5ed19f57d
commit
919a964b8f
@ -27,6 +27,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import data_adapter
|
||||
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
@ -52,7 +53,15 @@ logger = logging.get_logger(__name__)
|
||||
tf_logger = tf.get_logger()
|
||||
|
||||
TFModelInputType = Union[
|
||||
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
|
||||
List[tf.Tensor],
|
||||
List[np.ndarray],
|
||||
List[KerasTensor],
|
||||
Dict[str, tf.Tensor],
|
||||
Dict[str, np.ndarray],
|
||||
Dict[str, KerasTensor],
|
||||
tf.Tensor,
|
||||
np.ndarray,
|
||||
KerasTensor,
|
||||
]
|
||||
|
||||
|
||||
@ -348,7 +357,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
signature.pop("self", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray, KerasTensor)
|
||||
|
||||
if "inputs" in kwargs["kwargs_call"]:
|
||||
warnings.warn(
|
||||
@ -432,7 +441,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
else:
|
||||
if isinstance(input_ids, tf.Tensor) or input_ids is None:
|
||||
if isinstance(input_ids, (tf.Tensor, KerasTensor)) or input_ids is None:
|
||||
output[parameter_names[0]] = input_ids
|
||||
else:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user