Revert "Update deprecated Jax calls (#35919)" (#36880)

* Revert "Update deprecated Jax calls (#35919)"

This reverts commit f0d5b2ff04.

* Revert "Update deprecated Jax calls (#35919)"

This reverts commit f0d5b2ff04.

* update
This commit is contained in:
Arthur 2025-03-21 11:01:44 +01:00 committed by GitHub
parent 62116c967f
commit f19d018bff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 9 deletions

View File

@ -129,7 +129,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
flax_version = flax.__version__
jax_version = jax.__version__
jaxlib_version = jaxlib.__version__
jax_backend = jax.default_backend()
jax_backend = jax.lib.xla_bridge.get_backend().platform
info = {
"`transformers` version": version,

View File

@ -387,7 +387,7 @@ class FlaxLongT5Attention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, max=0)
relative_position = -jnp.clip(relative_position, a_max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -398,7 +398,7 @@ class FlaxLongT5Attention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@ -672,7 +672,7 @@ class FlaxLongT5LocalAttention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, max=0)
relative_position = -jnp.clip(relative_position, a_max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -683,7 +683,7 @@ class FlaxLongT5LocalAttention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@ -895,7 +895,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, max=0)
relative_position = -jnp.clip(relative_position, a_max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -906,7 +906,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

View File

@ -247,7 +247,7 @@ class FlaxT5Attention(nn.Module):
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, max=0)
relative_position = -jnp.clip(relative_position, a_max=0)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@ -258,7 +258,7 @@ class FlaxT5Attention(nn.Module):
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)