diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index d5762337b50..a8277588ffd 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -38,10 +38,6 @@ from .import_utils import ( ) -if is_flax_available(): - import jax.numpy as jnp - - class cached_property(property): """ Descriptor that mimics @property but caches output in member variable. @@ -624,6 +620,8 @@ def transpose(array, axes=None): return tf.transpose(array, perm=axes) elif is_jax_tensor(array): + import jax.numpy as jnp + return jnp.transpose(array, axes=axes) else: raise ValueError(f"Type not supported for transpose: {type(array)}.") @@ -643,6 +641,8 @@ def reshape(array, newshape): return tf.reshape(array, newshape) elif is_jax_tensor(array): + import jax.numpy as jnp + return jnp.reshape(array, newshape) else: raise ValueError(f"Type not supported for reshape: {type(array)}.") @@ -662,6 +662,8 @@ def squeeze(array, axis=None): return tf.squeeze(array, axis=axis) elif is_jax_tensor(array): + import jax.numpy as jnp + return jnp.squeeze(array, axis=axis) else: raise ValueError(f"Type not supported for squeeze: {type(array)}.") @@ -681,6 +683,8 @@ def expand_dims(array, axis): return tf.expand_dims(array, axis=axis) elif is_jax_tensor(array): + import jax.numpy as jnp + return jnp.expand_dims(array, axis=axis) else: raise ValueError(f"Type not supported for expand_dims: {type(array)}.")