Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Update deprecated Jax calls (#35919)
* Remove deprecated arguments for jax.numpy.clip.
* Update the jax version pin to >=0.4.27,<=0.4.38.
* Avoid use of the deprecated xla_bridge.get_backend().platform.

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
parent 1ddb64937c
commit f0d5b2ff04
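For context, the two API migrations applied in this commit look roughly like the following. This is a minimal sketch, assuming jax >= 0.4.27 is installed; the array x and the variable names are illustrative only.

import jax
import jax.numpy as jnp

x = jnp.array([-3, -1, 0, 2, 5])

# jnp.clip: the a_min/a_max keyword arguments are deprecated; newer JAX uses min/max.
clipped = jnp.clip(x, max=0)         # upper bound only, replaces a_max=0
bounded = jnp.clip(x, min=0, max=4)  # both bounds

# Backend query: jax.lib.xla_bridge.get_backend().platform is deprecated;
# jax.default_backend() returns the default platform string ("cpu", "gpu", or "tpu").
backend = jax.default_backend()
print(clipped, bounded, backend)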
setup.py
@@ -121,8 +121,8 @@ _deps = [
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
-    "jax>=0.4.1,<=0.4.13",
-    "jaxlib>=0.4.1,<=0.4.13",
+    "jax>=0.4.27,<=0.4.38",
+    "jaxlib>=0.4.27,<=0.4.38",
     "jieba",
     "jinja2>=3.1.0",
     "kenlm",
@@ -129,7 +129,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
         flax_version = flax.__version__
         jax_version = jax.__version__
         jaxlib_version = jaxlib.__version__
-        jax_backend = jax.lib.xla_bridge.get_backend().platform
+        jax_backend = jax.default_backend()

         info = {
             "`transformers` version": version,
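With this change the environment command gathers the Flax/JAX versions and backend using only public JAX APIs. A minimal standalone sketch of what it reports, assuming flax, jax, and jaxlib are installed (the dict keys here are illustrative, not the exact ones printed by the command):

import flax
import jax
import jaxlib

info = {
    "flax version": flax.__version__,
    "jax version": jax.__version__,
    "jaxlib version": jaxlib.__version__,
    "jax backend": jax.default_backend(),  # replaces jax.lib.xla_bridge.get_backend().platform
}
print(info)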
@@ -28,8 +28,8 @@ deps = {
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
-    "jax": "jax>=0.4.1,<=0.4.13",
-    "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
+    "jax": "jax>=0.4.27,<=0.4.38",
+    "jaxlib": "jaxlib>=0.4.27,<=0.4.38",
     "jieba": "jieba",
     "jinja2": "jinja2>=3.1.0",
     "kenlm": "kenlm",
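The same pin appears in both setup.py and the deps mapping above. A hypothetical way to check an installed jax against the new specifier, using the standard importlib.metadata and packaging modules (not part of this diff):

from importlib.metadata import version
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=0.4.27,<=0.4.38")
installed = version("jax")
print(installed, installed in spec)  # True if the installed jax satisfies the pin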
@@ -387,7 +387,7 @@ class FlaxLongT5Attention(nn.Module):
             relative_buckets += (relative_position > 0) * num_buckets
             relative_position = jnp.abs(relative_position)
         else:
-            relative_position = -jnp.clip(relative_position, a_max=0)
+            relative_position = -jnp.clip(relative_position, max=0)
         # now relative_position is in the range [0, inf)

         # half of the buckets are for exact increments in positions
@@ -398,7 +398,7 @@ class FlaxLongT5Attention(nn.Module):
         relative_position_if_large = max_exact + (
             jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
         )
-        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+        relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

         relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
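Both call sites above live in the T5-style relative-position bucketing helper. A standalone sketch of that logic, written against the new jnp.clip keyword and assuming jax >= 0.4.27; the function name and the defaults (num_buckets=32, max_distance=128) follow the usual T5 values but are illustrative here:

import jax.numpy as jnp

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Map signed relative positions to bucket indices: exact buckets for small
    # distances, logarithmically spaced buckets up to max_distance for large ones.
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0) * num_buckets
        relative_position = jnp.abs(relative_position)
    else:
        relative_position = -jnp.clip(relative_position, max=0)  # new keyword, was a_max=0
    # now relative_position is in the range [0, inf)

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    relative_position_if_large = max_exact + (
        jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)  # was a_max=

    relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets.astype("i4")

print(relative_position_bucket(jnp.arange(-4, 5)))

The remaining hunks in FlaxLongT5LocalAttention, FlaxLongT5TransientGlobalAttention, and FlaxT5Attention apply the identical keyword rename to the same two call sites.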
@@ -672,7 +672,7 @@ class FlaxLongT5LocalAttention(nn.Module):
             relative_buckets += (relative_position > 0) * num_buckets
             relative_position = jnp.abs(relative_position)
         else:
-            relative_position = -jnp.clip(relative_position, a_max=0)
+            relative_position = -jnp.clip(relative_position, max=0)
         # now relative_position is in the range [0, inf)

         # half of the buckets are for exact increments in positions
@@ -683,7 +683,7 @@ class FlaxLongT5LocalAttention(nn.Module):
         relative_position_if_large = max_exact + (
             jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
-        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+        relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

         relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@@ -895,7 +895,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
             relative_buckets += (relative_position > 0) * num_buckets
             relative_position = jnp.abs(relative_position)
         else:
-            relative_position = -jnp.clip(relative_position, a_max=0)
+            relative_position = -jnp.clip(relative_position, max=0)
         # now relative_position is in the range [0, inf)

         # half of the buckets are for exact increments in positions
@@ -906,7 +906,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
         relative_position_if_large = max_exact + (
             jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
         )
-        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+        relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

         relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
@@ -247,7 +247,7 @@ class FlaxT5Attention(nn.Module):
             relative_buckets += (relative_position > 0) * num_buckets
             relative_position = jnp.abs(relative_position)
         else:
-            relative_position = -jnp.clip(relative_position, a_max=0)
+            relative_position = -jnp.clip(relative_position, max=0)
         # now relative_position is in the range [0, inf)

         # half of the buckets are for exact increments in positions
@@ -258,7 +258,7 @@ class FlaxT5Attention(nn.Module):
         relative_position_if_large = max_exact + (
             jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
         )
-        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
+        relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

         relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)