mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Update modeling_flax_wav2vec2.py (#13680)
conv kernel_size to Tuple, Flax Version 0.3.5 breaking change, https://github.com/google/flax/releases/tag/v0.3.5
This commit is contained in:
parent
d16bec9530
commit
8565d38f30
@ -286,7 +286,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
|
||||
|
||||
self.conv = nn.Conv(
|
||||
features=self.config.conv_dim[self.layer_id],
|
||||
kernel_size=self.config.conv_kernel[self.layer_id],
|
||||
kernel_size=(self.config.conv_kernel[self.layer_id],),
|
||||
strides=(self.config.conv_stride[self.layer_id],),
|
||||
use_bias=self.config.conv_bias,
|
||||
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
|
||||
@ -310,7 +310,7 @@ class FlaxConvWithWeightNorm(nn.Module):
|
||||
def setup(self):
|
||||
self.conv = nn.Conv(
|
||||
features=self.config.hidden_size,
|
||||
kernel_size=self.config.num_conv_pos_embeddings,
|
||||
kernel_size=(self.config.num_conv_pos_embeddings,),
|
||||
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
|
||||
padding="VALID",
|
||||
feature_group_count=self.config.num_conv_pos_embedding_groups,
|
||||
@ -319,12 +319,12 @@ class FlaxConvWithWeightNorm(nn.Module):
|
||||
weight_shape = (
|
||||
self.conv.features,
|
||||
self.conv.features // self.conv.feature_group_count,
|
||||
self.conv.kernel_size,
|
||||
self.conv.kernel_size[0],
|
||||
)
|
||||
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
|
||||
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
|
||||
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
|
||||
self.prev_padding = self.conv.kernel_size // 2
|
||||
self.prev_padding = self.conv.kernel_size[0] // 2
|
||||
|
||||
def _get_normed_weights(self):
|
||||
weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
|
||||
|
Loading…
Reference in New Issue
Block a user