Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Use new parametrization based weight norm if available (#24030)
* Use new parametrization based weight norm if available

  See https://github.com/pytorch/pytorch/pull/103001

  Signed-off-by: Edward Z. Yang <ezyang@meta.com>

* handle copies

  Signed-off-by: Edward Z. Yang <ezyang@meta.com>

* black

  Signed-off-by: Edward Z. Yang <ezyang@meta.com>

---------

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
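For context, a minimal standalone sketch of the fallback pattern this commit applies to each model. The Conv1d here is a stand-in for the models' positional convolution layer, not the actual code:

    import torch.nn as nn

    conv = nn.Conv1d(16, 16, kernel_size=3)

    # Prefer the parametrization-based weight norm where it exists
    # (added in pytorch/pytorch#103001); otherwise fall back to the
    # older, now-deprecated nn.utils.weight_norm.
    weight_norm = nn.utils.weight_norm
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        weight_norm = nn.utils.parametrizations.weight_norm

    # dim=2 normalizes over the kernel dimension of the Conv1d weight.
    conv = weight_norm(conv, name="weight", dim=2)

One subtlety of the switch: the legacy API exposes the magnitude and direction tensors as conv.weight_g and conv.weight_v, while the parametrization-based API stores them under conv.parametrizations.weight.original0 and original1, which matters for code that references those attributes by name, such as the DeepSpeed ZeRO-3 branches in the hunks below.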
commit bc9ecef942
parent 4a55e47877
@@ -271,15 +271,19 @@ class HubertPositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -401,15 +401,19 @@ class SpeechT5PositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -306,15 +306,19 @@ class UniSpeechPositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -320,15 +320,19 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = UniSpeechSatSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -376,15 +376,19 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -356,15 +356,19 @@ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -279,15 +279,19 @@ class WavLMPositionalConvEmbedding(nn.Module):
             groups=config.num_conv_pos_embedding_groups,
         )
 
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
         else:
-            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            self.conv = weight_norm(self.conv, name="weight", dim=2)
 
         self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]