Llama Guard updates (#37872)

* Unhardcode use_chunked_attention, fix no_rope_layers

* Go back to exhaustive list of bools

* Conversion and modeling updates

* Fix rope

* Unhardcode rope

* Fix context length

* style

* Minor updates to conversion

* Use StaticCache

* Minor simplification

* DynamicCache 🤦

* Style

* Style
Pedro Cuenca, 2025-04-30 10:34:43 +02:00, committed by GitHub
parent 34f26e2c3e
commit 63cd4c76f3
3 changed files with 72 additions and 49 deletions

src/transformers/models/llama4/configuration_llama4.py

@@ -224,8 +224,13 @@ class Llama4TextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
<TODO>
<TODO>
no_rope_layers (`int`, *optional*): TODO
no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
no_rope_layers (`List[int]`, *optional*):
List with at least the same length as the number of layers in the model.
A `1` at an index position indicates that the corresponding layer will use RoPE,
while a `0` indicates that it's a NoPE layer.
no_rope_layer_interval (`int`, *optional*, defaults to 4):
If `no_rope_layers` is `None`, it will be created using a NoPE layer every
`no_rope_layer_interval` layers.
attention_chunk_size (`int`, *optional*, defaults to 8192):
<TODO>
attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
@@ -339,11 +344,15 @@ class Llama4TextConfig(PretrainedConfig):
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
# Backwards compatibility
if no_rope_layers == []:
no_rope_layers = None
default_no_rope_layers = [
int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
]
# no_rope_layers == [] is invalid as we cannot have 0 layers
self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
self.interleave_moe_layer_step = interleave_moe_layer_step
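
As a quick illustration of the new default (values below are hypothetical, not taken from the diff): with 8 layers and the default interval of 4, the list marks every 4th layer as a NoPE layer.

# Illustration only: default NoPE pattern for a hypothetical 8-layer config.
num_hidden_layers = 8
no_rope_layer_interval = 4
default_no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
]
print(default_no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> layers 3 and 7 are NoPE layers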

src/transformers/models/llama4/convert_llama4_weights_to_hf.py

@@ -65,6 +65,7 @@ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight", # might need to be fused for efficiency?
# r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
# Vision encoder mapping
@@ -166,8 +167,8 @@ def get_concat_dim(key):
return 0
def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
hidden_dim = 4 * int(2 * hidden_dim / 3)
def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
hidden_dim = ffn_exp * int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
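
For context, a worked example of the updated helper with hypothetical inputs (dim=4096 and the new defaults):

# Illustration only: step-by-step result of compute_intermediate_size(4096).
dim = 4096
hidden_dim = 4 * int(2 * dim / 3)                      # ffn_exp=4 -> 10920
hidden_dim = int(1.2 * hidden_dim)                     # ffn_dim_multiplier=1.2 -> 13104
hidden_dim = 1024 * ((hidden_dim + 1024 - 1) // 1024)  # round up to multiple_of=1024 -> 13312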
@@ -203,6 +204,8 @@ def max_context_length(model_path, instruct=False):
with open(os.path.join(model_path, "params.json"), "r") as f:
params = json.load(f)
params = params.get("model", params)
if params.get("moe_args") is None:
return 8192
num_experts = params["moe_args"]["num_experts"]
return 10485760 if num_experts == 16 else 1048576
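
The new guard means a params.json without a "moe_args" section (the dense Llama Guard case) falls back to an 8192-token context instead of raising a KeyError; a minimal sketch with a hypothetical params dict:

# Illustration only: dense checkpoints carry no "moe_args" entry.
dense_params = {"dim": 4096, "n_layers": 48}
assert dense_params.get("moe_args") is None  # -> max_context_length returns 8192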
@@ -242,24 +245,40 @@ def write_model(
# some constants from original code
rope_scaling = {
"rope_type": "llama3",
"factor": 8.0,
"factor": params.get("rope_scaling_factor", 8.0),
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"high_freq_factor": params.get("rope_high_freq_factor", 4.0),
"original_max_position_embeddings": 8192,
}
config_kwargs.update({"rope_scaling": rope_scaling})
if attention_chunk_size is None:
config_kwargs.update({"cache_implementation": "static"})
# compute additional params for weight conversion
num_heads_per_shard = num_heads // num_shards
dim_per_head = dim // num_heads
# intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
intermediate_size_mlp = compute_intermediate_size(
dim,
ffn_exp=params["ffn_exp"],
multiple_of=params["multiple_of"],
ffn_dim_multiplier=params["ffn_dim_multiplier"],
)
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_experts = params["moe_args"]["num_experts"]
interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
if hasattr(params, "moe_args"):
num_experts = params["moe_args"]["num_experts"]
interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
else:
# Dense model (possibly Llama Guard) - disable all moe layers
num_experts = 0
interleave_moe_layer_step = 0
config_kwargs.update({"moe_layers": []})
# Ensure all layers are rope if `nope_layer_interval` is None
no_rope_layer_interval = params["nope_layer_interval"]
no_rope_layer_interval = num_heads * 2 if no_rope_layer_interval is None else no_rope_layer_interval
bos_token_id = 200000
eos_token_id = [200001, 200007, 200008] if instruct else 200001
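
A sketch of what the new branches produce for a dense checkpoint (all values hypothetical):

# Illustration only: dense (non-MoE) params with no NoPE layers configured.
params = {"n_heads": 32, "n_layers": 48, "nope_layer_interval": None}
num_experts, interleave_moe_layer_step = 0, 0  # MoE disabled, moe_layers = []
# An interval larger than the layer count (here 64 > 48) means (layer_idx + 1) % interval
# is never 0, so every layer keeps RoPE.
no_rope_layer_interval = params["n_heads"] * 2 if params["nope_layer_interval"] is None else params["nope_layer_interval"]
assert all((i + 1) % no_rope_layer_interval != 0 for i in range(params["n_layers"]))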
@@ -273,7 +292,7 @@ def write_model(
rope_theta=rope_theta,
num_hidden_layers=num_layers,
intermediate_size=8192,
intermediate_size_mlp=16384,
intermediate_size_mlp=intermediate_size_mlp,
max_position_embeddings=max_context_length(input_base_path, instruct),
num_local_experts=num_experts,
interleave_moe_layer_step=interleave_moe_layer_step,
@@ -336,7 +355,7 @@ def write_model(
sharded_keys = []
for _key in all_keys_raw:
try:
if (loaded[0][_key] == loaded[1][_key]).all():
if num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all():
repeated_keys.append(_key)
else:
sharded_keys.append(_key)
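
The extra num_shards check matters because a single-shard checkpoint has no loaded[1] to compare against; a minimal sketch of the classification logic under that assumption (tensor values hypothetical):

# Illustration only: with one shard, every key is treated as repeated across shards.
import torch
loaded = [{"tok_embeddings.weight": torch.zeros(4, 4)}]
num_shards = len(loaded)
repeated_keys, sharded_keys = [], []
for _key in loaded[0]:
    # The short-circuit avoids indexing loaded[1], which does not exist here.
    if num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all():
        repeated_keys.append(_key)
    else:
        sharded_keys.append(_key)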
@@ -354,7 +373,7 @@ def write_model(
for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"):
new_key = new_keys[key]
print(key, new_key)
if not is_param_same_across_shards(new_key):
if num_shards > 1 and not is_param_same_across_shards(new_key):
current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)]
else:
print(f"{key} (now {new_key}) is the same across all shards.")
@@ -565,8 +584,8 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
"text_post_train", 61, 8
) # <|text_post_train_reserved_special_token_8|>, ..., <|text_post_train_reserved_special_token_68|>
# 200080, ..., 201133
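
For context, a plausible reconstruction of the helper used above, assuming it returns `count` reserved tokens starting at the given index (this matches the updated comment, but the actual implementation may differ):

# Assumed behaviour only, inferred from the comment in the diff.
def get_reserved_special_tokens(name, count, start_index=0):
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]

# get_reserved_special_tokens("text_post_train", 61, 8)
# -> <|text_post_train_reserved_special_token_8|> ... <|text_post_train_reserved_special_token_68|>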
LLAMA4_VISION_SPECIAL_TOKENS = [
@@ -621,15 +640,6 @@ class Llama4Converter(TikTokenConverter):
**kwargs,
)
# to check
# import tiktoken
# model = tiktoken.Encoding(
# name=Path(model_path).name,
# pat_str=self.O200K_PATTERN,
# mergeable_ranks=mergeable_ranks,
# special_tokens=self.special_tokens,
# )
instruct = chat_template is not None
self.update_post_processor(self.converted_tokenizer)
# finer special_tokens_map.json
@@ -687,12 +697,10 @@ if __name__ == "__main__":
parser.add_argument(
"--input_dir",
type=str,
default="/fsx/arthur/Llama-4-17B-Omni-Instruct-Original",
help="Location of the local folder copied from the Hub.",
)
parser.add_argument(
"--output_dir",
default="llama4_hf_vision",
type=str,
help="Location to write HF model and tokenizer",
)

src/transformers/models/llama4/modeling_llama4.py

@@ -20,12 +20,11 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridChunkedCache
from ...cache_utils import Cache, DynamicCache, HybridChunkedCache
from ...generation import GenerationMixin
from ...integrations.hub_kernels import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -287,7 +286,7 @@ class Llama4TextAttention(nn.Module):
self.attn_temperature_tuning = config.attn_temperature_tuning
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers
self.use_rope = config.no_rope_layers[layer_idx]
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
@@ -374,7 +373,7 @@ class Llama4TextDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Llama4TextAttention(config, layer_idx)
self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope
self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx])
self.is_moe_layer = layer_idx in config.moe_layers
if self.is_moe_layer: # the 128E model interleaves dense / sparse
self.feed_forward = Llama4TextMoe(config)
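
To see what the un-hardcoded flags evaluate to, a small sketch with hypothetical config values:

# Illustration only: per-layer flags after this change.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # default pattern from Llama4TextConfig
attention_chunk_size = 8192                # set to None to disable chunked attention entirely
for layer_idx, use_rope in enumerate(no_rope_layers):
    # Chunked attention only applies to RoPE layers, and only when a chunk size is set.
    use_chunked_attention = attention_chunk_size is not None and bool(use_rope)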
@@ -643,7 +642,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
if use_cache and past_key_values is None:
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
if self.config.get_text_config().get("attention_chunk_size") is not None:
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
else:
past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -740,6 +742,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
sequence_length = input_tensor.shape[1]
cache_position = cache_position.to(self.device)
attention_chunk_size = self.config.attention_chunk_size
using_chunked_attention = attention_chunk_size is not None
first_cache_position = cache_position[0]
@@ -748,26 +751,28 @@ class Llama4TextModel(Llama4PreTrainedModel):
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
first_cache_position + sequence_length > attention_chunk_size
)
key_length = (
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
if using_chunked_attention:
cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
first_cache_position + sequence_length > attention_chunk_size
)
key_length = (
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
)
if use_cache
else full_cache_length
)
if use_cache
else full_cache_length
)
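
As a worked example of the key length computed above (hypothetical numbers, attention_chunk_size = 8192, use_cache=True):

# Illustration only: the three cases for key_length with a chunk size of 8192.
attention_chunk_size = 8192
# 1) decoding past the first chunk: first_cache_position=10000, sequence_length=1
key_length = attention_chunk_size + 1 - 1  # 8192
# 2) prefill crossing the chunk boundary: first_cache_position=8000, sequence_length=500
key_length = 8000 + 500                    # 8500
# 3) still inside the first chunk: first_cache_position=0, sequence_length=500
key_length = attention_chunk_size          # 8192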
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
chunked_attention_mask = make_flex_block_causal_mask(
attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
)
if using_chunked_attention:
offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
chunked_attention_mask = make_flex_block_causal_mask(
attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
)
attention_mask = make_flex_block_causal_mask(
attention_mask,
query_length=sequence_length,
@@ -780,15 +785,16 @@ class Llama4TextModel(Llama4PreTrainedModel):
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
dtype, device = input_tensor.dtype, input_tensor.device
target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=max(full_cache_length, attention_chunk_size),
target_length=target_length,
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if full_cache_length > self.config.attention_chunk_size:
if using_chunked_attention and full_cache_length > attention_chunk_size:
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
end_idx = start_idx + key_length
chunked_attention_mask = self.create_chunked_attention_mask(