Use new parametrization based weight norm if available (#24030)

* Use new parametrization based weight norm if available

See https://github.com/pytorch/pytorch/pull/103001

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

* handle copies

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

* black

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

---------

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
This commit is contained in:
Edward Z. Yang 2023-06-06 13:34:57 -04:00 committed by GitHub
parent 4a55e47877
commit bc9ecef942
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 14 deletions

View File

@ -271,15 +271,19 @@ class HubertPositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -401,15 +401,19 @@ class SpeechT5PositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -306,15 +306,19 @@ class UniSpeechPositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -320,15 +320,19 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = UniSpeechSatSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -376,15 +376,19 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -356,15 +356,19 @@ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@ -279,15 +279,19 @@ class WavLMPositionalConvEmbedding(nn.Module):
groups=config.num_conv_pos_embedding_groups,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]