mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* Revert "Update deprecated Jax calls (#35919)" This reverts commitf0d5b2ff04
. * Revert "Update deprecated Jax calls (#35919)" This reverts commitf0d5b2ff04
. * udpate
This commit is contained in:
parent
62116c967f
commit
f19d018bff
@ -129,7 +129,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
|
||||
flax_version = flax.__version__
|
||||
jax_version = jax.__version__
|
||||
jaxlib_version = jaxlib.__version__
|
||||
jax_backend = jax.default_backend()
|
||||
jax_backend = jax.lib.xla_bridge.get_backend().platform
|
||||
|
||||
info = {
|
||||
"`transformers` version": version,
|
||||
|
@ -387,7 +387,7 @@ class FlaxLongT5Attention(nn.Module):
|
||||
relative_buckets += (relative_position > 0) * num_buckets
|
||||
relative_position = jnp.abs(relative_position)
|
||||
else:
|
||||
relative_position = -jnp.clip(relative_position, max=0)
|
||||
relative_position = -jnp.clip(relative_position, a_max=0)
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
@ -398,7 +398,7 @@ class FlaxLongT5Attention(nn.Module):
|
||||
relative_position_if_large = max_exact + (
|
||||
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
||||
|
||||
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
||||
|
||||
@ -672,7 +672,7 @@ class FlaxLongT5LocalAttention(nn.Module):
|
||||
relative_buckets += (relative_position > 0) * num_buckets
|
||||
relative_position = jnp.abs(relative_position)
|
||||
else:
|
||||
relative_position = -jnp.clip(relative_position, max=0)
|
||||
relative_position = -jnp.clip(relative_position, a_max=0)
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
@ -683,7 +683,7 @@ class FlaxLongT5LocalAttention(nn.Module):
|
||||
relative_position_if_large = max_exact + (
|
||||
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
||||
|
||||
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
||||
|
||||
@ -895,7 +895,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
|
||||
relative_buckets += (relative_position > 0) * num_buckets
|
||||
relative_position = jnp.abs(relative_position)
|
||||
else:
|
||||
relative_position = -jnp.clip(relative_position, max=0)
|
||||
relative_position = -jnp.clip(relative_position, a_max=0)
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
@ -906,7 +906,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module):
|
||||
relative_position_if_large = max_exact + (
|
||||
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
||||
|
||||
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
||||
|
||||
|
@ -247,7 +247,7 @@ class FlaxT5Attention(nn.Module):
|
||||
relative_buckets += (relative_position > 0) * num_buckets
|
||||
relative_position = jnp.abs(relative_position)
|
||||
else:
|
||||
relative_position = -jnp.clip(relative_position, max=0)
|
||||
relative_position = -jnp.clip(relative_position, a_max=0)
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
@ -258,7 +258,7 @@ class FlaxT5Attention(nn.Module):
|
||||
relative_position_if_large = max_exact + (
|
||||
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
|
||||
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
|
||||
|
||||
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user