Mirror of https://github.com/huggingface/transformers.git
Llama Guard updates (#37872)
* Unhardcode use_chunked_attention, fix no_rope_layers
* Go back to exhaustive list of bools
* Conversion and modeling updates
* Fix rope
* Unhardcode rope
* Fix context length
* style
* Minor updates to conversion
* Use StaticCache
* Minor simplification
* DynamicCache 🤦
* Style
* Style
Commit 63cd4c76f3 (parent 34f26e2c3e)
@@ -224,8 +224,13 @@ class Llama4TextConfig(PretrainedConfig):
            Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
            <TODO>
            <TODO>
-        no_rope_layers (`int`, *optional*): TODO
-        no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
+        no_rope_layers (`List[int]`, *optional*):
+            List with at least the same length as the number of layers in the model.
+            A `1` at an index position indicates that the corresponding layer will use RoPE,
+            while a `0` indicates that it's a NoPE layer.
+        no_rope_layer_interval (`int`, *optional*, defaults to 4):
+            If `no_rope_layers` is `None`, it will be created using a NoPE layer every
+            `no_rope_layer_interval` layers.
        attention_chunk_size (`int`, *optional*, defaults to 8192):
            <TODO>
        attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
@@ -339,11 +344,15 @@ class Llama4TextConfig(PretrainedConfig):
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise

        # Backwards compatibility
        if no_rope_layers == []:
            no_rope_layers = None

        default_no_rope_layers = [
            int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
        ]

        # no_rope_layers == [] is invalid as we cannot have 0 layers
        self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers

        self.interleave_moe_layer_step = interleave_moe_layer_step
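To make the backwards-compatibility block above concrete, here is a standalone sketch (hypothetical layer count, the default interval of 4) of how the exhaustive list of booleans is derived when `no_rope_layers` is not supplied: every fourth layer becomes a NoPE layer.

# Standalone sketch with assumed values; mirrors the comprehension used in the config.
num_hidden_layers = 12        # hypothetical model depth
no_rope_layer_interval = 4    # config default

default_no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
]
print(default_no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0] -> 0 marks the NoPE layers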
@@ -65,6 +65,7 @@ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
    # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
    r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+   r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
    r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",

    # Vision encoder mapping
@@ -166,8 +167,8 @@ def get_concat_dim(key):
    return 0


-def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
-    hidden_dim = 4 * int(2 * hidden_dim / 3)
+def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
+    hidden_dim = ffn_exp * int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
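For a quick sanity check of the new signature, the snippet below exercises a copy of the helper with an assumed hidden size of 5120 and the new defaults; the result lands back on the 16384 value that used to be hardcoded for `intermediate_size_mlp` (see the `write_model` hunk further down).

# Copy of the updated helper, exercised with an assumed hidden size.
def compute_intermediate_size(hidden_dim, ffn_exp=4, multiple_of=1024, ffn_dim_multiplier=1.2):
    hidden_dim = ffn_exp * int(2 * hidden_dim / 3)  # expand by ffn_exp instead of a fixed 4x
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # apply the checkpoint's multiplier
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up to a multiple
    return hidden_dim

print(compute_intermediate_size(5120))  # 16384, matching the previously hardcoded intermediate_size_mlp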
@@ -203,6 +204,8 @@ def max_context_length(model_path, instruct=False):
    with open(os.path.join(model_path, "params.json"), "r") as f:
        params = json.load(f)
        params = params.get("model", params)
+       if params.get("moe_args") is None:
+           return 8192
        num_experts = params["moe_args"]["num_experts"]
        return 10485760 if num_experts == 16 else 1048576
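The same logic, distilled into a self-contained function over an already-parsed `params` dict (file loading omitted, sample inputs assumed), shows the three context-length buckets: 8192 for dense checkpoints without `moe_args`, 10M for the 16-expert model, and 1M otherwise.

# Self-contained distillation of the branch above; `params` is assumed to be
# the parsed contents of params.json.
def max_context_length_from_params(params):
    params = params.get("model", params)
    if params.get("moe_args") is None:
        return 8192  # dense (e.g. Llama Guard style) checkpoints
    num_experts = params["moe_args"]["num_experts"]
    return 10485760 if num_experts == 16 else 1048576

print(max_context_length_from_params({"dim": 4096}))                      # 8192
print(max_context_length_from_params({"moe_args": {"num_experts": 16}}))  # 10485760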
@@ -242,24 +245,40 @@ def write_model(
    # some constants from original code
    rope_scaling = {
        "rope_type": "llama3",
-       "factor": 8.0,
+       "factor": params.get("rope_scaling_factor", 8.0),
        "low_freq_factor": 1.0,
-       "high_freq_factor": 4.0,
+       "high_freq_factor": params.get("rope_high_freq_factor", 4.0),
        "original_max_position_embeddings": 8192,
    }
    config_kwargs.update({"rope_scaling": rope_scaling})

+   if attention_chunk_size is None:
+       config_kwargs.update({"cache_implementation": "static"})

    # compute additional params for weight conversion
    num_heads_per_shard = num_heads // num_shards
    dim_per_head = dim // num_heads
    # intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+   intermediate_size_mlp = compute_intermediate_size(
+       dim,
+       ffn_exp=params["ffn_exp"],
+       multiple_of=params["multiple_of"],
+       ffn_dim_multiplier=params["ffn_dim_multiplier"],
+   )

    num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA

-   num_experts = params["moe_args"]["num_experts"]
-   interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+   if hasattr(params, "moe_args"):
+       num_experts = params["moe_args"]["num_experts"]
+       interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+   else:
+       # Dense model (possibly Llama Guard) - disable all moe layers
+       num_experts = 0
+       interleave_moe_layer_step = 0
+       config_kwargs.update({"moe_layers": []})

+   # Ensure all layers are rope if `nope_layer_interval` is None
    no_rope_layer_interval = params["nope_layer_interval"]
+   no_rope_layer_interval = num_heads * 2 if no_rope_layer_interval is None else no_rope_layer_interval

    bos_token_id = 200000
    eos_token_id = [200001, 200007, 200008] if instruct else 200001
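As an illustration of the dense-vs-MoE branch above, the sketch below reproduces the intended outcome with a plain dict membership check (the script itself branches on the params object directly); the function name and sample inputs are assumptions for the example only.

# Illustrative only: dense checkpoints without "moe_args" disable MoE entirely,
# while MoE checkpoints read their expert settings from params.json.
def moe_settings(params):
    if "moe_args" in params:  # membership check used for this sketch
        num_experts = params["moe_args"]["num_experts"]
        interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
        moe_layers = None  # keep the config's default layout
    else:
        num_experts, interleave_moe_layer_step, moe_layers = 0, 0, []
    return num_experts, interleave_moe_layer_step, moe_layers

print(moe_settings({"n_kv_heads": 8}))                  # (0, 0, [])  -> dense / Llama Guard path
print(moe_settings({"moe_args": {"num_experts": 16}}))  # (16, 1, None)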
@@ -273,7 +292,7 @@ def write_model(
        rope_theta=rope_theta,
        num_hidden_layers=num_layers,
        intermediate_size=8192,
-       intermediate_size_mlp=16384,
+       intermediate_size_mlp=intermediate_size_mlp,
        max_position_embeddings=max_context_length(input_base_path, instruct),
        num_local_experts=num_experts,
        interleave_moe_layer_step=interleave_moe_layer_step,
@@ -336,7 +355,7 @@ def write_model(
    sharded_keys = []
    for _key in all_keys_raw:
        try:
-           if (loaded[0][_key] == loaded[1][_key]).all():
+           if num_shards == 1 or (loaded[0][_key] == loaded[1][_key]).all():
                repeated_keys.append(_key)
            else:
                sharded_keys.append(_key)
@@ -354,7 +373,7 @@ def write_model(
    for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"):
        new_key = new_keys[key]
        print(key, new_key)
-       if not is_param_same_across_shards(new_key):
+       if num_shards > 1 and not is_param_same_across_shards(new_key):
            current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)]
        else:
            print(f"{key} (now {new_key}) is the same across all shards.")
@@ -565,8 +584,8 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
    "<|python_end|>",
    "<|finetune_right_pad|>",
] + get_reserved_special_tokens(
-   "text_post_train", 61, 6
-)  # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
+   "text_post_train", 61, 8
+)  # <|text_post_train_reserved_special_token_8|>, ..., <|text_post_train_reserved_special_token_68|>

# 200080, ..., 201133
LLAMA4_VISION_SPECIAL_TOKENS = [
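To see why the comment now ends at token 68, here is a hypothetical stand-in for `get_reserved_special_tokens` consistent with the naming in the comment (the real helper is defined elsewhere in the conversion script and may differ): 61 tokens starting at index 8 run from ..._8 through ..._68.

# Hypothetical stand-in consistent with the comment; not the script's actual helper.
def get_reserved_special_tokens(prefix, count, start):
    return [f"<|{prefix}_reserved_special_token_{i}|>" for i in range(start, start + count)]

tokens = get_reserved_special_tokens("text_post_train", 61, 8)
print(len(tokens), tokens[0], tokens[-1])
# 61 <|text_post_train_reserved_special_token_8|> <|text_post_train_reserved_special_token_68|>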
@@ -621,15 +640,6 @@ class Llama4Converter(TikTokenConverter):
            **kwargs,
        )

-       # to check
-       # import tiktoken
-       # model = tiktoken.Encoding(
-       #     name=Path(model_path).name,
-       #     pat_str=self.O200K_PATTERN,
-       #     mergeable_ranks=mergeable_ranks,
-       #     special_tokens=self.special_tokens,
-       # )
-
        instruct = chat_template is not None
        self.update_post_processor(self.converted_tokenizer)
        # finer special_tokens_map.json
@@ -687,12 +697,10 @@ if __name__ == "__main__":
    parser.add_argument(
        "--input_dir",
        type=str,
-       default="/fsx/arthur/Llama-4-17B-Omni-Instruct-Original",
        help="Location of the local folder copied from the Hub.",
    )
    parser.add_argument(
        "--output_dir",
-       default="llama4_hf_vision",
        type=str,
        help="Location to write HF model and tokenizer",
    )
@@ -20,12 +20,11 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig

from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridChunkedCache
+from ...cache_utils import Cache, DynamicCache, HybridChunkedCache
from ...generation import GenerationMixin
from ...integrations.hub_kernels import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
@@ -287,7 +286,7 @@ class Llama4TextAttention(nn.Module):
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
-       self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
+       self.use_rope = config.no_rope_layers[layer_idx]
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
@@ -374,7 +373,7 @@ class Llama4TextDecoderLayer(nn.Module):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4TextAttention(config, layer_idx)
-       self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
+       self.use_chunked_attention = config.attention_chunk_size is not None and bool(config.no_rope_layers[layer_idx])
        self.is_moe_layer = layer_idx in config.moe_layers
        if self.is_moe_layer:  # the 128E model interleaves dense / sparse
            self.feed_forward = Llama4TextMoe(config)
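Putting the two attention/decoder-layer changes together, the snippet below walks an assumed 8-layer `no_rope_layers` pattern and shows the per-layer flags the modules now derive from the config instead of the hardcoded interval of 4; setting `attention_chunk_size` to `None` (the Llama Guard case) turns chunked attention off everywhere.

# Assumed 8-layer pattern; mirrors how use_rope / use_chunked_attention are now derived.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]
attention_chunk_size = 8192  # set to None for a dense, non-chunked (Llama Guard style) config

for layer_idx, flag in enumerate(no_rope_layers):
    use_rope = flag
    use_chunked_attention = attention_chunk_size is not None and bool(flag)
    print(f"layer {layer_idx}: use_rope={use_rope}, use_chunked_attention={use_chunked_attention}")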
@@ -643,7 +642,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
-           past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+           if self.config.get_text_config().get("attention_chunk_size") is not None:
+               past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
+           else:
+               past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -740,6 +742,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
        sequence_length = input_tensor.shape[1]
        cache_position = cache_position.to(self.device)
        attention_chunk_size = self.config.attention_chunk_size
+       using_chunked_attention = attention_chunk_size is not None

        first_cache_position = cache_position[0]
@@ -748,26 +751,28 @@ class Llama4TextModel(Llama4PreTrainedModel):
        else:
            full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length

-       cond1 = first_cache_position >= attention_chunk_size
-       cond2 = (first_cache_position < attention_chunk_size) & (
-           first_cache_position + sequence_length > attention_chunk_size
-       )
-       key_length = (
-           torch.where(
-               cond1,
-               attention_chunk_size + sequence_length - 1,
-               torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
-           )
-           if use_cache
-           else full_cache_length
-       )
+       if using_chunked_attention:
+           cond1 = first_cache_position >= attention_chunk_size
+           cond2 = (first_cache_position < attention_chunk_size) & (
+               first_cache_position + sequence_length > attention_chunk_size
+           )
+           key_length = (
+               torch.where(
+                   cond1,
+                   attention_chunk_size + sequence_length - 1,
+                   torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+               )
+               if use_cache
+               else full_cache_length
+           )

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
-               offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
-               chunked_attention_mask = make_flex_block_causal_mask(
-                   attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
-               )
+               if using_chunked_attention:
+                   offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0))
+                   chunked_attention_mask = make_flex_block_causal_mask(
+                       attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets
+                   )
                attention_mask = make_flex_block_causal_mask(
                    attention_mask,
                    query_length=sequence_length,
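For intuition about the `key_length` computation that is now guarded by `using_chunked_attention`, the following example plugs in illustrative numbers: once the cache has moved past a full chunk (`cond1`), the key window becomes one chunk plus the new tokens minus one.

import torch

# Illustrative numbers only; shows the torch.where selection used above.
attention_chunk_size = 8192
sequence_length = 16
first_cache_position = torch.tensor(8200)  # already past one full chunk
use_cache = True
full_cache_length = 8216

cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
    first_cache_position + sequence_length > attention_chunk_size
)
key_length = (
    torch.where(
        cond1,
        attention_chunk_size + sequence_length - 1,
        torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
    )
    if use_cache
    else full_cache_length
)
print(key_length)  # tensor(8207): one chunk (8192) plus 16 new tokens minus 1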
@@ -780,15 +785,16 @@ class Llama4TextModel(Llama4PreTrainedModel):

            # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
            dtype, device = input_tensor.dtype, input_tensor.device
+           target_length = max(full_cache_length, attention_chunk_size) if using_chunked_attention else full_cache_length
            causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
-               target_length=max(full_cache_length, attention_chunk_size),
+               target_length=target_length,
                dtype=dtype,
                cache_position=cache_position,
                batch_size=input_tensor.shape[0],
            )
-       if full_cache_length > self.config.attention_chunk_size:
+       if using_chunked_attention and full_cache_length > attention_chunk_size:
            start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
            end_idx = start_idx + key_length
            chunked_attention_mask = self.create_chunked_attention_mask(