mirror of https://github.com/huggingface/transformers.git
[StableLm] Add QK normalization and Parallel Residual Support (#29745)

* init: add StableLm 2 support
* add integration test for parallel residual and qk layernorm
* update(modeling): match qk norm naming for consistency with phi/persimmon
* fix(tests): run fwd/bwd on random init test model to jitter norm weights off identity
* `use_parallel_residual`: add copy pointer to `GPTNeoXLayer.forward`
* refactor: rename head states var in `StableLmLayerNormPerHead`
* tests: update test model and add generate check
This commit is contained in:
parent 8c00b53eb0
commit 2f12e40822
@@ -83,6 +83,11 @@ class StableLmConfig(PretrainedConfig):
             is an experimental feature, subject to breaking API changes in future versions.
         use_qkv_bias (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use bias for qkv layers.
+        qk_layernorm (`bool`, *optional*, defaults to `False`):
+            Whether or not to normalize, per head, the Queries and Keys after projecting the hidden states.
+        use_parallel_residual (`bool`, *optional*, defaults to `False`):
+            Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
+            speedup at large scales.
         hidden_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio after applying the MLP to the hidden states.
         attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -123,6 +128,8 @@ class StableLmConfig(PretrainedConfig):
         rope_theta=10_000,
         rope_scaling=None,
         use_qkv_bias=False,
+        qk_layernorm=False,
+        use_parallel_residual=False,
         hidden_dropout=0.0,
         attention_dropout=0.0,
         partial_rotary_factor=0.25,
@@ -146,6 +153,8 @@ class StableLmConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
         self.use_qkv_bias = use_qkv_bias
+        self.qk_layernorm = qk_layernorm
+        self.use_parallel_residual = use_parallel_residual
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
         self.partial_rotary_factor = partial_rotary_factor
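The `qk_layernorm` and `use_parallel_residual` flags (alongside the existing `use_qkv_bias`) surface directly on `StableLmConfig`, so a StableLM-2-style model can be configured without touching the modeling code. A minimal, illustrative sketch (not part of this diff; the tiny sizes are arbitrary) enabling the new options on a randomly initialized model:

    from transformers import StableLmConfig, StableLmForCausalLM

    config = StableLmConfig(
        vocab_size=1024,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        use_qkv_bias=True,           # bias on the q/k/v projections
        qk_layernorm=True,           # per-head LayerNorm on queries and keys
        use_parallel_residual=True,  # GPT-NeoX-style parallel attention/MLP residual
    )
    model = StableLmForCausalLM(config)
    # With use_parallel_residual=True the decoder layers skip post_attention_layernorm:
    print(model.model.layers[0].post_attention_layernorm)  # None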
@@ -203,6 +203,21 @@ class StableLmMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
+class StableLmLayerNormPerHead(nn.Module):
+    def __init__(self, dim, num_heads, eps=1e-5, bias=False):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
+
+    def forward(self, hidden_states: torch.Tensor):
+        # Split along the num_heads axis to get per-head inputs
+        # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
+        states_per_heads = torch.split(hidden_states, 1, dim=1)
+        # Normalize and merge the heads back together
+        return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
+
+
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
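`StableLmLayerNormPerHead` keeps one `nn.LayerNorm` per attention head and applies it to that head's slice only, so the output has the same shape as the input. A quick illustrative shape check (assuming the class above is in scope; the sizes are arbitrary):

    import torch

    batch_size, num_heads, seq_len, head_dim = 2, 4, 7, 8
    per_head_norm = StableLmLayerNormPerHead(dim=head_dim, num_heads=num_heads)
    states = torch.randn(batch_size, num_heads, seq_len, head_dim)
    out = per_head_norm(states)
    print(out.shape)  # torch.Size([2, 4, 7, 8]) -- normalized over head_dim, per head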
@@ -250,6 +265,13 @@ class StableLmAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
+        self.qk_layernorm = config.qk_layernorm
+        if self.qk_layernorm:
+            self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
+            self.k_layernorm = StableLmLayerNormPerHead(
+                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
+            )
+
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self._init_rope()
 
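When `qk_layernorm` is enabled, the attention module carries one LayerNorm per query head and one per key/value head, which matters for grouped-query configurations where the two counts differ. Continuing the illustrative model from the config sketch above:

    attn = model.model.layers[0].self_attn
    print(len(attn.q_layernorm.norms))  # == config.num_attention_heads
    print(len(attn.k_layernorm.norms))  # == config.num_key_value_heads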
@@ -300,6 +322,10 @@ class StableLmAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -409,6 +435,10 @@ class StableLmSdpaAttention(StableLmAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -513,6 +543,10 @@ class StableLmFlashAttention2(StableLmAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
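The same two-line `qk_layernorm` block is mirrored across the eager, SDPA, and FlashAttention-2 forward passes above, so the normalization is applied regardless of the attention backend. An illustrative way to pin a specific backend (the checkpoint name is the one used by the integration tests below):

    from transformers import StableLmForCausalLM

    model_sdpa = StableLmForCausalLM.from_pretrained(
        "stabilityai/tiny-random-stablelm-2", attn_implementation="sdpa"
    )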
@@ -678,10 +712,13 @@ ATTENTION_CLASSES = {
 class StableLmDecoderLayer(nn.Module):
     def __init__(self, config: StableLmConfig, layer_idx: int):
         super().__init__()
+        self.use_parallel_residual = config.use_parallel_residual
         self.hidden_size = config.hidden_size
         self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
         self.mlp = StableLmMLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_attention_layernorm = None
+        if not self.use_parallel_residual:
+            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
 
@@ -719,7 +756,7 @@ class StableLmDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -727,15 +764,22 @@ class StableLmDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        hidden_states = residual + hidden_states
 
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = hidden_states + residual
+        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
+        if self.use_parallel_residual:
+            # x = x + attn(ln1(x)) + mlp(ln1(x))
+            mlp_output = self.mlp(hidden_states)
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + self_attn_output + mlp_output
+        else:
+            # x = x + attn(ln1(x))
+            # x = x + mlp(ln2(x))
+            residual = residual + self_attn_output
+            # Fully Connected
+            mlp_output = self.mlp(self.post_attention_layernorm(residual))
+            mlp_output = self.dropout(mlp_output)
+            hidden_states = residual + mlp_output
 
         outputs = (hidden_states,)
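The two branches above implement the residual layouts sketched in their comments: the parallel path feeds the same `input_layernorm` output to both attention and MLP and performs a single residual add, while the sequential path applies `post_attention_layernorm` to the post-attention stream before the MLP. A toy, illustrative sketch of the two layouts with attention and MLP stubbed out by linear maps (not the modeling code itself):

    import torch
    import torch.nn as nn

    hidden = 16
    ln1, ln2 = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
    attn, mlp = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)
    x = torch.randn(2, 5, hidden)

    # parallel residual: x = x + attn(ln1(x)) + mlp(ln1(x))
    parallel = x + attn(ln1(x)) + mlp(ln1(x))

    # sequential residual: x = x + attn(ln1(x)); x = x + mlp(ln2(x))
    h = x + attn(ln1(x))
    sequential = h + mlp(ln2(h))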
@@ -483,6 +483,40 @@ class StableLmModelIntegrationTest(unittest.TestCase):
         EXPECTED_TEXT_COMPLETION = """My favorite food has always been pizza, but lately I’ve been craving something different. I’ve been trying to eat healthier and I’ve"""
         self.assertEqual(text, EXPECTED_TEXT_COMPLETION)
 
+    @slow
+    def test_model_tiny_random_stablelm_2_logits(self):
+        # Check parallel residual and qk layernorm forward pass
+        input_ids = {"input_ids": torch.tensor([[510, 8588, 310, 1900, 9386]], dtype=torch.long, device=torch_device)}
+
+        model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2").to(torch_device)
+        model.eval()
+
+        output = model(**input_ids).logits
+
+        # Expected mean on dim = -1
+        EXPECTED_MEAN = torch.tensor([[-2.7196, -3.6099, -2.6877, -3.1973, -3.9344]]).to(torch_device)
+        self.assertTrue(torch.allclose(output.mean(dim=-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4))
+
+        # Expected logits sliced from [0, 0, 0:30]
+        EXPECTED_SLICE = torch.tensor([2.8364, 5.3811, 5.1659, 7.5485, 4.3219, 6.3315, 1.3967, 6.9147, 3.9679, 6.4786, 5.9176, 3.3067, 5.2917, 0.1485, 3.9630, 7.9947, 10.6727, 9.6757, 8.8772, 8.3527, 7.8445, 6.6025, 5.5786, 7.0985, 6.1369, 3.4259, 1.9397, 4.6157, 4.8105, 3.1768]).to(torch_device)  # fmt: skip
+        self.assertTrue(torch.allclose(output[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4))
+
+    @slow
+    def test_model_tiny_random_stablelm_2_generation(self):
+        # Check parallel residual and qk layernorm generation
+        tokenizer = AutoTokenizer.from_pretrained("stabilityai/tiny-random-stablelm-2")
+        model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2")
+        input_ids = tokenizer.encode(
+            "My favorite ride at the amusement park",
+            return_tensors="pt",
+        )
+
+        outputs = model.generate(input_ids, max_new_tokens=20, temperature=0)
+        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        EXPECTED_TEXT_COMPLETION = """My favorite ride at the amusement park is the 2000-mile roller coaster. It's a thrilling ride filled with roller coast"""
+        self.assertEqual(text, EXPECTED_TEXT_COMPLETION)
+
     @require_bitsandbytes
     @slow
     @require_flash_attn