Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[SwitchTransformers] Remove unused module (#25427)

* remove unused module
* remove old feed_forward_proj
* fixup
parent d6bf08f7f6
commit 5347d00092
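For context on the config change below: after this commit the dense-layer activation is passed directly as `dense_act_fn` instead of being derived from `feed_forward_proj`. A minimal usage sketch (not part of the diff; the value shown is the default from the new signature):

from transformers import SwitchTransformersConfig

# The activation is now configured directly; `feed_forward_proj` is gone.
config = SwitchTransformersConfig(dense_act_fn="relu")
print(config.dense_act_fn)  # "relu"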
src/transformers/models/switch_transformers/configuration_switch_transformers.py

@@ -122,7 +122,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         router_z_loss_coef=0.001,
         router_aux_loss_coef=0.001,
         initializer_factor=1.0,
-        feed_forward_proj="relu",
+        dense_act_fn="relu",
         is_encoder_decoder=True,
         add_router_probs=False,
         use_cache=True,
@@ -171,27 +171,12 @@ class SwitchTransformersConfig(PretrainedConfig):
         self.dropout_rate = dropout_rate
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
-        self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
         self.add_router_probs = add_router_probs
 
         self.router_z_loss_coef = router_z_loss_coef
         self.router_aux_loss_coef = router_aux_loss_coef
-
-        act_info = self.feed_forward_proj.split("-")
-        self.dense_act_fn = act_info[-1]
-        self.is_gated_act = act_info[0] == "gated"
-
-        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
-            raise ValueError(
-                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
-                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
-                "'gated-gelu' or 'relu'"
-            )
-
-        # for backwards compatibility
-        if feed_forward_proj == "gated-gelu":
-            self.dense_act_fn = "gelu_new"
+        self.dense_act_fn = dense_act_fn
 
         super().__init__(
             pad_token_id=pad_token_id,
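The block removed above is the old `feed_forward_proj` parsing. Its behaviour, reproduced in isolation as a standalone sketch (hypothetical helper name, not library code):

def parse_feed_forward_proj(feed_forward_proj: str):
    # Mirrors the removed logic: "relu" -> ("relu", False), "gated-gelu" -> ("gelu_new", True)
    act_info = feed_forward_proj.split("-")
    dense_act_fn = act_info[-1]
    is_gated_act = act_info[0] == "gated"
    if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
        raise ValueError(f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function.")
    if feed_forward_proj == "gated-gelu":  # old backwards-compatibility special case
        dense_act_fn = "gelu_new"
    return dense_act_fn, is_gated_act

assert parse_feed_forward_proj("relu") == ("relu", False)
assert parse_feed_forward_proj("gated-gelu") == ("gelu_new", True)

Since the modeling code never instantiates a gated MLP (the commit title calls it an unused module), none of this parsing was exercised, which is why it is dropped in favour of a plain `dense_act_fn` attribute.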
src/transformers/models/switch_transformers/modeling_switch_transformers.py

@@ -282,25 +282,6 @@ class SwitchTransformersDenseActDense(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
-class SwitchTransformersDenseGatedActDense(nn.Module):
-    def __init__(self, config: SwitchTransformersConfig):
-        super().__init__()
-        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.act = ACT2FN[config.dense_act_fn]
-
-    def forward(self, hidden_states):
-        hidden_gelu = self.act(self.wi_0(hidden_states))
-        hidden_linear = self.wi_1(hidden_states)
-        hidden_states = hidden_gelu * hidden_linear
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.wo(hidden_states)
-        return hidden_states
-
-
 class SwitchTransformersSparseMLP(nn.Module):
     r"""
     Implementation of the Switch Transformers Sparse MLP module.
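For contrast with the removed gated variant above, the MLP that SwitchTransformers actually uses is the single-projection `SwitchTransformersDenseActDense`. A simplified sketch of its shape (paraphrased, details such as dtype handling omitted):

import torch.nn as nn
from transformers.activations import ACT2FN

class DenseActDenseSketch(nn.Module):
    # Simplified stand-in for SwitchTransformersDenseActDense: one input projection,
    # activation, dropout, one output projection -- no gating branch.
    def __init__(self, d_model, d_ff, dropout_rate, dense_act_fn):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = ACT2FN[dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.act(self.wi(hidden_states))
        hidden_states = self.dropout(hidden_states)
        return self.wo(hidden_states)

The removed gated class needed two input projections (`wi_0`, `wi_1`) whose outputs are multiplied together; since nothing in the model ever constructed it, it was dead code.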
@@ -861,16 +842,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
             module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
                 module.wo.bias.data.zero_()
-        elif isinstance(module, SwitchTransformersDenseGatedActDense):
-            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
-                module.wi_0.bias.data.zero_()
-            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
-                module.wi_1.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-            if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
         elif isinstance(module, SwitchTransformersAttention):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
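The surviving branches keep the same scaling shown in the removed lines: each projection's weight std is the initializer factor times the inverse square root of its input dimension. A quick standalone check of what those values come out to (example sizes only, not tied to any particular checkpoint):

factor = 1.0                # config.initializer_factor default
d_model, d_ff = 768, 2048   # illustrative dimensions
std_wi = factor * d_model ** -0.5   # ~0.0361, input projection (wi)
std_wo = factor * d_ff ** -0.5      # ~0.0221, output projection (wo)
print(round(std_wi, 4), round(std_wo, 4))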