Update modeling_flax_wav2vec2.py (#13680)

conv kernel_size to Tuple,
Flax Version 0.3.5 breaking change, https://github.com/google/flax/releases/tag/v0.3.5
This commit is contained in:
Kamal Raj 2021-09-22 03:06:13 +05:30 committed by GitHub
parent d16bec9530
commit 8565d38f30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,7 +286,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
self.conv = nn.Conv(
features=self.config.conv_dim[self.layer_id],
kernel_size=self.config.conv_kernel[self.layer_id],
kernel_size=(self.config.conv_kernel[self.layer_id],),
strides=(self.config.conv_stride[self.layer_id],),
use_bias=self.config.conv_bias,
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
@ -310,7 +310,7 @@ class FlaxConvWithWeightNorm(nn.Module):
def setup(self):
self.conv = nn.Conv(
features=self.config.hidden_size,
kernel_size=self.config.num_conv_pos_embeddings,
kernel_size=(self.config.num_conv_pos_embeddings,),
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
padding="VALID",
feature_group_count=self.config.num_conv_pos_embedding_groups,
@ -319,12 +319,12 @@ class FlaxConvWithWeightNorm(nn.Module):
weight_shape = (
self.conv.features,
self.conv.features // self.conv.feature_group_count,
self.conv.kernel_size,
self.conv.kernel_size[0],
)
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
self.prev_padding = self.conv.kernel_size // 2
self.prev_padding = self.conv.kernel_size[0] // 2
def _get_normed_weights(self):
weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]