mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
fix type check (#12638)
This commit is contained in:
parent
2dd9440d08
commit
f8f9a679a0
@ -1715,10 +1715,10 @@ def is_tensor(x):
|
||||
return True
|
||||
|
||||
if is_flax_available():
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
import jax.numpy as jnp
|
||||
from jax.core import Tracer
|
||||
|
||||
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
|
||||
if isinstance(x, (jnp.ndarray, Tracer)):
|
||||
return True
|
||||
|
||||
return isinstance(x, np.ndarray)
|
||||
|
Loading…
Reference in New Issue
Block a user