Update deprecated Jax calls (#35919)

* Remove deprecated arguments for jax.numpy.clip.

* Remove deprecated arguments for jax.numpy.clip.

* Update jax version to 0.4.27 to 0.4.38.

* Avoid use of deprecated xla_bridge.get_backend().platform

Co-authored-by: Jake Vanderplas <jakevdp@google.com>

---------

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
rasmi 2025-03-20 06:51:51 -04:00 committed by GitHub
parent 1ddb64937c
commit f0d5b2ff04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 13 additions and 13 deletions

View File

@ -121,8 +121,8 @@ _deps = [
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
"jax>=0.4.1,<=0.4.13",
"jaxlib>=0.4.1,<=0.4.13",
"jax>=0.4.27,<=0.4.38",
"jaxlib>=0.4.27,<=0.4.38",
"jieba",
"jinja2>=3.1.0",
"kenlm",

View File

@ -129,7 +129,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
flax_version = flax.__version__
jax_version = jax.__version__
jaxlib_version = jaxlib.__version__
jax_backend = jax.lib.xla_bridge.get_backend().platform
jax_backend = jax.default_backend()
info = {
"`transformers` version": version,

View File

@ -28,8 +28,8 @@ deps = {
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.4.1,<=0.4.13",
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
"jax": "jax>=0.4.27,<=0.4.38",
"jaxlib": "jaxlib>=0.4.27,<=0.4.38",
"jieba": "jieba",
"jinja2": "jinja2>=3.1.0",
"kenlm": "kenlm",

View File

@ -387,7 +387,7 @@ class FlaxLongT5Attention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -398,7 +398,7 @@ class FlaxLongT5Attention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@ -672,7 +672,7 @@ class FlaxLongT5LocalAttention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -683,7 +683,7 @@ class FlaxLongT5LocalAttention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@ -895,7 +895,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -906,7 +906,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

View File

@ -247,7 +247,7 @@ class FlaxT5Attention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -258,7 +258,7 @@ class FlaxT5Attention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)