[StableLm] Add QK normalization and Parallel Residual Support (#29745)

* init: add StableLm 2 support

* add integration test for parallel residual and qk layernorm

* update(modeling): match qk norm naming for consistency with phi/persimmon

* fix(tests): run fwd/bwd on random init test model to jitter norm weights off identity

* `use_parallel_residual`: add copy pointer to `GPTNeoXLayer.forward`

* refactor: rename head states var in `StableLmLayerNormPerHead`

* tests: update test model and add generate check
Jonathan Tow authored on 2024-04-08 17:51:58 -04:00; committed by GitHub
parent 8c00b53eb0
commit 2f12e40822
3 changed files with 97 additions and 10 deletions

src/transformers/models/stablelm/configuration_stablelm.py

@@ -83,6 +83,11 @@ class StableLmConfig(PretrainedConfig):
is an experimental feature, subject to breaking API changes in future versions.
use_qkv_bias (`bool`, *optional*, defaults to `False`):
Whether or not the model should use bias for qkv layers.
qk_layernorm (`bool`, *optional*, defaults to `False`):
Whether or not to normalize, per head, the Queries and Keys after projecting the hidden states.
use_parallel_residual (`bool`, *optional*, defaults to `False`):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after applying the MLP to the hidden states.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -123,6 +128,8 @@ class StableLmConfig(PretrainedConfig):
rope_theta=10_000,
rope_scaling=None,
use_qkv_bias=False,
qk_layernorm=False,
use_parallel_residual=False,
hidden_dropout=0.0,
attention_dropout=0.0,
partial_rotary_factor=0.25,
@@ -146,6 +153,8 @@ class StableLmConfig(PretrainedConfig):
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.use_qkv_bias = use_qkv_bias
self.qk_layernorm = qk_layernorm
self.use_parallel_residual = use_parallel_residual
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.partial_rotary_factor = partial_rotary_factor
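Note (not part of the diff): both new options are plain `StableLmConfig` kwargs and default to `False`, so the default configuration keeps the previous StableLM behaviour. A minimal usage sketch with arbitrary toy sizes:

from transformers import StableLmConfig, StableLmForCausalLM

config = StableLmConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    qk_layernorm=True,           # per-head LayerNorm on the projected queries/keys
    use_parallel_residual=True,  # GPT-NeoX-style parallel attention/MLP residual
)
model = StableLmForCausalLM(config)  # randomly initialized toy model using both new options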

src/transformers/models/stablelm/modeling_stablelm.py

@@ -203,6 +203,21 @@ class StableLmMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class StableLmLayerNormPerHead(nn.Module):
def __init__(self, dim, num_heads, eps=1e-5, bias=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
def forward(self, hidden_states: torch.Tensor):
# Split along the num_heads axis to get per-head inputs
# [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
states_per_heads = torch.split(hidden_states, 1, dim=1)
# Normalize and merge the heads back together
return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
@@ -250,6 +265,13 @@ class StableLmAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
self.k_layernorm = StableLmLayerNormPerHead(
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self._init_rope()
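Note (not part of the diff): a standalone sketch of what the per-head normalization computes when `qk_layernorm` is enabled, using assumed toy shapes and plain LayerNorms in place of `StableLmLayerNormPerHead`. Each head gets its own `nn.LayerNorm` over `head_dim`, applied to the `[batch_size, num_heads, seq_len, head_dim]` query/key tensors right after the head reshape in the attention forward passes below.

import torch
from torch import nn

batch_size, num_heads, seq_len, head_dim = 2, 4, 6, 8
query_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# one LayerNorm per head, mirroring StableLmLayerNormPerHead
norms = nn.ModuleList([nn.LayerNorm(head_dim, eps=1e-5, bias=False) for _ in range(num_heads)])
# split into num_heads tensors of shape [batch_size, 1, seq_len, head_dim], normalize, merge back
per_head = torch.split(query_states, 1, dim=1)
normed = torch.cat([norm(states) for norm, states in zip(norms, per_head)], dim=1)
assert normed.shape == query_states.shape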
@@ -300,6 +322,10 @@ class StableLmAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
@@ -409,6 +435,10 @@ class StableLmSdpaAttention(StableLmAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
@@ -513,6 +543,10 @@ class StableLmFlashAttention2(StableLmAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
@@ -678,11 +712,14 @@ ATTENTION_CLASSES = {
class StableLmDecoderLayer(nn.Module):
def __init__(self, config: StableLmConfig, layer_idx: int):
super().__init__()
+        self.use_parallel_residual = config.use_parallel_residual
self.hidden_size = config.hidden_size
self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = StableLmMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_attention_layernorm = None
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
def forward(
@@ -719,7 +756,7 @@ class StableLmDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -727,15 +764,22 @@ class StableLmDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
)
-        hidden_states = residual + hidden_states
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = hidden_states + residual
+        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
+        if self.use_parallel_residual:
+            # x = x + attn(ln1(x)) + mlp(ln1(x))
+            # Fully Connected
+            mlp_output = self.mlp(hidden_states)
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + self_attn_output + mlp_output
+        else:
+            # x = x + attn(ln1(x))
+            # x = x + mlp(ln2(x))
+            residual = residual + self_attn_output
+            # Fully Connected
+            mlp_output = self.mlp(self.post_attention_layernorm(residual))
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + mlp_output
outputs = (hidden_states,)
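Note (not part of the diff): a compact sketch of the two residual layouts the decoder layer now supports, with `nn.Linear` stand-ins for the real attention and MLP blocks. In the parallel formulation a single `input_layernorm` output feeds both branches and `post_attention_layernorm` is unused (hence `None` above); the default path keeps the original post-attention norm.

import torch
from torch import nn

def parallel_block(hidden, ln1, attn, mlp):
    # x = x + attn(ln1(x)) + mlp(ln1(x))
    normed = ln1(hidden)
    return hidden + attn(normed) + mlp(normed)

def sequential_block(hidden, ln1, ln2, attn, mlp):
    # x = x + attn(ln1(x)); x = x + mlp(ln2(x))
    hidden = hidden + attn(ln1(hidden))
    return hidden + mlp(ln2(hidden))

dim = 16
hidden = torch.randn(2, 5, dim)
ln1, ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.Linear(dim, dim)  # stand-in for self-attention
mlp = nn.Linear(dim, dim)   # stand-in for StableLmMLP
out_parallel = parallel_block(hidden, ln1, attn, mlp)
out_sequential = sequential_block(hidden, ln1, ln2, attn, mlp)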

tests/models/stablelm/test_modeling_stablelm.py

@@ -483,6 +483,40 @@ class StableLmModelIntegrationTest(unittest.TestCase):
EXPECTED_TEXT_COMPLETION = """My favorite food has always been pizza, but lately I've been craving something different. I've been trying to eat healthier and I've"""
self.assertEqual(text, EXPECTED_TEXT_COMPLETION)
@slow
def test_model_tiny_random_stablelm_2_logits(self):
# Check parallel residual and qk layernorm forward pass
input_ids = {"input_ids": torch.tensor([[510, 8588, 310, 1900, 9386]], dtype=torch.long, device=torch_device)}
model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2").to(torch_device)
model.eval()
output = model(**input_ids).logits
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.7196, -3.6099, -2.6877, -3.1973, -3.9344]]).to(torch_device)
self.assertTrue(torch.allclose(output.mean(dim=-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4))
# Expected logits sliced from [0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([2.8364, 5.3811, 5.1659, 7.5485, 4.3219, 6.3315, 1.3967, 6.9147, 3.9679, 6.4786, 5.9176, 3.3067, 5.2917, 0.1485, 3.9630, 7.9947, 10.6727, 9.6757, 8.8772, 8.3527, 7.8445, 6.6025, 5.5786, 7.0985, 6.1369, 3.4259, 1.9397, 4.6157, 4.8105, 3.1768]).to(torch_device) # fmt: skip
self.assertTrue(torch.allclose(output[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4))
@slow
def test_model_tiny_random_stablelm_2_generation(self):
# Check parallel residual and qk layernorm generation
tokenizer = AutoTokenizer.from_pretrained("stabilityai/tiny-random-stablelm-2")
model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2")
input_ids = tokenizer.encode(
"My favorite ride at the amusement park",
return_tensors="pt",
)
outputs = model.generate(input_ids, max_new_tokens=20, temperature=0)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
EXPECTED_TEXT_COMPLETION = """My favorite ride at the amusement park is the 2000-mile roller coaster. It's a thrilling ride filled with roller coast"""
self.assertEqual(text, EXPECTED_TEXT_COMPLETION)
@require_bitsandbytes
@slow
@require_flash_attn