[SwitchTransformers] Remove unused module (#25427)

* remove unused module

* remove old feed_forward_proj

* fixup
Arthur 2023-08-17 17:03:41 +02:00 committed by GitHub
parent d6bf08f7f6
commit 5347d00092
2 changed files with 2 additions and 46 deletions
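
For context, a minimal usage sketch of the configuration after this change (not part of the diff; "relu" is the default value shown below):

from transformers import SwitchTransformersConfig

# `feed_forward_proj` is gone; the dense activation is now passed directly.
config = SwitchTransformersConfig(dense_act_fn="relu")
print(config.dense_act_fn)  # "relu"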

src/transformers/models/switch_transformers/configuration_switch_transformers.py

@@ -122,7 +122,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         router_z_loss_coef=0.001,
         router_aux_loss_coef=0.001,
         initializer_factor=1.0,
-        feed_forward_proj="relu",
+        dense_act_fn="relu",
         is_encoder_decoder=True,
         add_router_probs=False,
         use_cache=True,
@@ -171,27 +171,12 @@ class SwitchTransformersConfig(PretrainedConfig):
         self.dropout_rate = dropout_rate
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
-        self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
         self.add_router_probs = add_router_probs
         self.router_z_loss_coef = router_z_loss_coef
         self.router_aux_loss_coef = router_aux_loss_coef
-
-        act_info = self.feed_forward_proj.split("-")
-        self.dense_act_fn = act_info[-1]
-        self.is_gated_act = act_info[0] == "gated"
-
-        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
-            raise ValueError(
-                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
-                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
-                "'gated-gelu' or 'relu'"
-            )
-
-        # for backwards compatibility
-        if feed_forward_proj == "gated-gelu":
-            self.dense_act_fn = "gelu_new"
+        self.dense_act_fn = dense_act_fn

         super().__init__(
             pad_token_id=pad_token_id,
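
The block removed above folded the old `feed_forward_proj` string into `dense_act_fn` (e.g. "gated-gelu" mapped to "gelu_new"). For anyone migrating an old-style value by hand, a hypothetical standalone helper mirroring that removed logic (not part of the library) could look like:

def dense_act_fn_from_feed_forward_proj(feed_forward_proj: str) -> str:
    # Mirrors the parsing removed in this commit.
    act_info = feed_forward_proj.split("-")
    if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
        raise ValueError(f"{feed_forward_proj!r} is not a valid `feed_forward_proj` value")
    if feed_forward_proj == "gated-gelu":
        # backwards-compatibility special case carried over from the removed code
        return "gelu_new"
    return act_info[-1]

assert dense_act_fn_from_feed_forward_proj("relu") == "relu"
assert dense_act_fn_from_feed_forward_proj("gated-gelu") == "gelu_new"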

src/transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -282,25 +282,6 @@ class SwitchTransformersDenseActDense(nn.Module):
         return hidden_states


-# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
-class SwitchTransformersDenseGatedActDense(nn.Module):
-    def __init__(self, config: SwitchTransformersConfig):
-        super().__init__()
-        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.act = ACT2FN[config.dense_act_fn]
-
-    def forward(self, hidden_states):
-        hidden_gelu = self.act(self.wi_0(hidden_states))
-        hidden_linear = self.wi_1(hidden_states)
-        hidden_states = hidden_gelu * hidden_linear
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.wo(hidden_states)
-        return hidden_states
-
-
 class SwitchTransformersSparseMLP(nn.Module):
     r"""
     Implementation of the Switch Transformers Sparse MLP module.
@@ -861,16 +842,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
             module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
                 module.wo.bias.data.zero_()
-        elif isinstance(module, SwitchTransformersDenseGatedActDense):
-            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
-                module.wi_0.bias.data.zero_()
-            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
-                module.wi_1.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-            if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
         elif isinstance(module, SwitchTransformersAttention):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
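
With the gated variant removed, SwitchTransformersDenseActDense is the only dense feed-forward module left in this file. A minimal sketch exercising it with toy dimensions (not part of the diff):

import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersDenseActDense,
)

# Small config just for illustration; dense_act_fn now comes straight from the config.
config = SwitchTransformersConfig(d_model=64, d_ff=128, dense_act_fn="relu")
mlp = SwitchTransformersDenseActDense(config)
hidden = torch.randn(1, 4, config.d_model)
print(mlp(hidden).shape)  # torch.Size([1, 4, 64])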