refactor: remove custom BarkLayerNorm (#39003)

`nn.LayerNorm` supports `bias=False` since PyTorch 2.1
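
For context, a minimal standalone sketch (not part of this commit) illustrating the equivalence on PyTorch >= 2.1: `nn.LayerNorm(hidden_size, bias=False)` computes the same result as the removed `BarkLayerNorm` did with `bias=False` (weight-only affine transform, default `eps=1e-5`).

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_size = 8
    x = torch.randn(2, 4, hidden_size)

    # what the removed BarkLayerNorm computed with bias=False:
    # weight initialized to ones, bias=None, eps=1e-5
    weight = torch.ones(hidden_size)
    custom = F.layer_norm(x, weight.shape, weight, None, eps=1e-5)

    # nn.LayerNorm with bias=False (supported since PyTorch 2.1) registers only
    # a `weight` parameter (initialized to ones) and uses the same default eps=1e-5
    builtin = nn.LayerNorm(hidden_size, bias=False)(x)

    assert torch.allclose(custom, builtin)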
Author: Enno Hermann, 2025-06-25 17:07:52 +02:00 (committed by GitHub)
parent 3c1d4dfbac
commit 3233e9b7c3


@@ -282,18 +282,6 @@ BARK_ATTENTION_CLASSES = {
 }
 
 
-class BarkLayerNorm(nn.Module):
-    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
-
-    def __init__(self, hidden_size, bias=True):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
-
-    def forward(self, input):
-        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, eps=1e-5)
-
-
 class BarkMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -315,11 +303,10 @@ class BarkBlock(GradientCheckpointingLayer):
         super().__init__()
 
         if is_causal:
-            # if causal, uses handmade LayerNorm, so that the layerNorm bias is optional
-            # this handmade layerNorm is used to stick with Bark choice of leaving optional bias in
-            # AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
-            self.layernorm_1 = BarkLayerNorm(config.hidden_size, bias=config.bias)
-            self.layernorm_2 = BarkLayerNorm(config.hidden_size, bias=config.bias)
+            # if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias
+            # in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
+            self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
+            self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
         else:
             self.layernorm_1 = nn.LayerNorm(config.hidden_size)
             self.layernorm_2 = nn.LayerNorm(config.hidden_size)
@@ -427,7 +414,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
         self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
-        self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
+        self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
 
         self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
         self.gradient_checkpointing = False
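
Note: `nn.LayerNorm` stores its learnable parameters under the same names as the removed module (`weight`, plus `bias` only when `bias=True`), so existing checkpoints are expected to keep loading with unchanged state dict keys such as `layernorm_final.weight`. A small sketch of that assumption:

    import torch.nn as nn

    # with bias=False only a `weight` key is registered, matching the old
    # BarkLayerNorm, which set self.bias = None when bias=False
    ln = nn.LayerNorm(8, bias=False)
    print(list(ln.state_dict().keys()))  # ['weight']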