fix type check (#12638)

This commit is contained in:
Suraj Patil 2021-07-12 15:18:43 +05:30 committed by GitHub
parent 2dd9440d08
commit f8f9a679a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1715,10 +1715,10 @@ def is_tensor(x):
return True
if is_flax_available():
import jaxlib.xla_extension as jax_xla
import jax.numpy as jnp
from jax.core import Tracer
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
if isinstance(x, (jnp.ndarray, Tracer)):
return True
return isinstance(x, np.ndarray)