BLTForCausalLM

ita.zaporozhets@huggingface.co 2025-07-03 13:25:37 +00:00
parent bbe1bdb599
commit 3772dc7646
6 changed files with 130 additions and 1327 deletions

View File

@@ -10,6 +10,7 @@ from safetensors.torch import load_file, save_file
from transformers.models.blt_wip.configuration_blt import BLTConfig
from transformers.models.blt_wip.modeling_blt import BLTModel
from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
from transformers.utils import logging as transformers_logging
@@ -156,6 +157,8 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str
"global_config": global_config,
}
main_config_dict["tie_word_embeddings"] = False
logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
return main_config_dict
@@ -203,8 +206,6 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor
elif "state_dict" in entropy_weights:
entropy_weights = entropy_weights["state_dict"]
logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors")
unified_weights = main_weights.copy()
for key, tensor in entropy_weights.items():
@@ -213,6 +214,22 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor
unified_weights = apply_weight_mapping(unified_weights)
decoder_lm_head_key = "local_decoder.lm_head.weight"
top_lm_head_key = "lm_head.weight"
unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
del unified_weights[decoder_lm_head_key]
prefixed_weights = {}
for key, tensor in unified_weights.items():
if key == top_lm_head_key:
prefixed_weights[key] = tensor
elif not key.startswith("model."):
prefixed_weights[f"model.{key}"] = tensor
else:
prefixed_weights[key] = tensor
unified_weights = prefixed_weights
return unified_weights
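For reference, the renaming above can be reproduced in isolation. The sketch below uses dummy tensors and mirrors the two steps shown in this hunk: promote the decoder head to a top-level `lm_head.weight`, then prefix every remaining key with `model.` so it loads into `BLTForCausalLM.model`.

```python
import torch

# Dummy tensors standing in for the merged BLT weights; shapes are illustrative.
unified_weights = {
    "local_decoder.lm_head.weight": torch.zeros(260, 1024),
    "local_encoder.embed_tokens.weight": torch.zeros(260, 1024),
}

# 1) Promote the decoder head to the top-level lm_head expected by BLTForCausalLM.
unified_weights["lm_head.weight"] = unified_weights.pop("local_decoder.lm_head.weight")

# 2) Prefix every other key with "model." so it lands inside BLTForCausalLM.model.
prefixed_weights = {
    key if key == "lm_head.weight" or key.startswith("model.") else f"model.{key}": tensor
    for key, tensor in unified_weights.items()
}

assert set(prefixed_weights) == {"lm_head.weight", "model.local_encoder.embed_tokens.weight"}
```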
@@ -233,8 +250,6 @@ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
with open(tokenizer_path, "w") as f:
json.dump(tokenizer_config, f, indent=2)
logger.info(f"Tokenizer config saved to {tokenizer_path}")
def push_to_hub(
local_dir: str,
@@ -344,7 +359,7 @@ def main():
parser.add_argument(
"--push_to_hub",
type=str,
default="itazap/blt-1b-converted",
default=None,
)
parser.add_argument(
"--hub_private",

View File

@@ -23,19 +23,8 @@ torch.cuda.empty_cache()
def main(prompt: str = "my name is", model_name: str = "blt-1b"):
device = "cuda"
blt_model = BLTModel.from_pretrained("itazap/blt-1b-converted")
causal_lm = BLTForCausalLM(blt_model.config)
causal_lm.model.load_state_dict(blt_model.state_dict(), strict=False)
causal_lm.lm_head.weight = blt_model.local_decoder.lm_head.weight
causal_lm.save_pretrained("./blt-1b-causallm")
# TRUE causal_lm.lm_head.weight == blt_model.local_decoder.lm_head.weight
model = BLTForCausalLM.from_pretrained("./blt-1b-causallm").to(device)
model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device)
# FALSE model.lm_head.weight != blt_model.local_decoder.lm_head.weight
tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True)
input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device)
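Because BLTForCausalLM now mixes in GenerationMixin and carries its own lm_head, the demo no longer needs the manual state-dict surgery above. A minimal sketch of the intended usage, assuming the tokenizer module path shown and that the stock greedy generate() loop runs on this WIP model (generation settings are illustrative):

```python
import torch
from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
# Module path for the tokenizer is assumed here; adjust to wherever BLTTokenizer is exposed.
from transformers.models.blt_wip.tokenization_blt import BLTTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device).eval()
tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True)

input_ids = torch.tensor([tokenizer.encode("my name is", add_eos=False)]).to(device)
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0]))
```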

View File

@@ -321,23 +321,15 @@ class BLTConfig(PretrainedConfig):
encoder_config=None,
decoder_config=None,
global_config=None,
# Generation configuration
bos_token_id=1,
eos_token_id=2,
pad_token_id=-1,
tie_word_embeddings=False,
**kwargs,
):
# Basic model configuration
self.tie_word_embeddings = tie_word_embeddings
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
# Generation configuration
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.return_dict = True
# Patching configuration
self.patch_in_forward = patch_in_forward
self.patch_size = patch_size
@@ -387,7 +379,7 @@ class BLTConfig(PretrainedConfig):
elif isinstance(global_config, BLTGlobalTransformerConfig):
self.global_config = global_config
super().__init__(**kwargs)
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
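Routing tie_word_embeddings through super().__init__() hands the flag to the standard PretrainedConfig machinery instead of only storing it as an attribute. A small hedged check of what that buys, assuming the remaining constructor arguments all have defaults:

```python
from transformers.models.blt_wip.configuration_blt import BLTConfig

config = BLTConfig(tie_word_embeddings=False)
assert config.tie_word_embeddings is False

# The flag now survives a to_dict()/from_dict() round trip, so converted
# checkpoints keep their untied lm_head across save_pretrained()/from_pretrained().
assert BLTConfig.from_dict(config.to_dict()).tie_word_embeddings is False
```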
__all__ = [
"BLTConfig",

View File

@@ -666,14 +666,11 @@ class BLTLocalDecoder(nn.Module):
BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
)
self.lm_head = nn.Linear(
config.hidden_size,
config.vocab_size,
bias=False,
)
# self.lm_head = nn.Linear(
# config.hidden_size,
# config.vocab_size,
# bias=False,
# )
def forward(
@@ -718,7 +715,7 @@ class BLTLocalDecoder(nn.Module):
hidden_states = layer_outputs[0]
logits = self.norm(hidden_states)
logits = self.lm_head(logits)
# logits = self.lm_head(logits)
return logits, cache
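With the projection commented out here, the decoder (and therefore BLTModel) hands back normalized hidden states, and the vocabulary projection moves up to BLTForCausalLM.lm_head. A minimal sketch of the new split, with illustrative sizes rather than the real config values:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 1024, 260              # illustrative sizes only
hidden_states = torch.randn(1, 8, hidden_size)   # what BLTLocalDecoder/BLTModel now return
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # now owned by BLTForCausalLM

logits = lm_head(hidden_states)                  # the projection happens once, at the top level
assert logits.shape == (1, 8, vocab_size)
```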
@@ -914,24 +911,24 @@ class BLTPreTrainedModel(PreTrainedModel):
emb_std = module.config.hidden_size ** (-0.5)
module.embed_tokens._custom_std = emb_std
module.lm_head._custom_std = emb_std
elif isinstance(module, BLTForCausalLM):
if module.lm_head is not None:
module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5)
class BLTModel(BLTPreTrainedModel):
def __init__(self, config: BLTConfig):
super().__init__(config)
self.config = config
self.local_encoder = BLTLocalEncoder(config.encoder_config)
self.global_transformer = BLTGlobalTransformer(config.global_config)
self.local_decoder = BLTLocalDecoder(config.decoder_config)
self.encoder_hash_tok_embedding = init_hash_embeddings(
config,
local_encoder_dim=config.encoder_config.hidden_size,
encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
)
if self.config.patch_in_forward:
self.patcher = BLTPatcher(config.patcher_config)
self.patcher.eval()
@@ -940,9 +937,30 @@ class BLTModel(BLTPreTrainedModel):
else:
self.patcher = None
def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None):
def forward(
self,
tokens: torch.Tensor,
patch_lengths: Optional[torch.Tensor] = None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
**kwargs,
):
"""
Args:
tokens (torch.Tensor): Input token ids.
patch_lengths (Optional[torch.Tensor]): Patch lengths for patching.
attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility.
Returns:
torch.Tensor: Final hidden states (as before).
"""
batch_size, sequence_length = tokens.shape
# Handle patching
if patch_lengths is None:
if self.config.patching_mode == PatchingModeEnum.entropy:
@@ -955,24 +973,20 @@ class BLTModel(BLTPreTrainedModel):
device=tokens.device,
)
else:
# Default to byte-level patching
patch_lengths = process_patch_lengths(
torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
self.config.max_patch_length
)
patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
)
encoder_embeds = compute_hash_embeddings(
tokens, self.local_encoder, self.encoder_hash_tok_embedding,
self.config.encoder_hash_byte_group_nb_functions,
self.config.encoder_hash_byte_group_size,
self.config.encoder_hash_byte_group_vocab,
)
encoder_hidden_states, encoder_cross_states = self.local_encoder(
input_ids=tokens,
input_embeds=encoder_embeds,
@@ -982,18 +996,14 @@ class BLTModel(BLTPreTrainedModel):
num_patches=patch_lengths.shape[1],
patch_ids=patch_ids,
)
global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
global_hidden_states, _ = self.global_transformer(
input_embeds=global_hidden_states,
)
decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
)
output, _ = self.local_decoder(
tokens=tokens,
embeds=encoder_hidden_states,
@@ -1002,7 +1012,11 @@ class BLTModel(BLTPreTrainedModel):
cross_mask=cross_attn_mask_dec,
full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
)
if output_hidden_states or output_attentions:
if return_dict:
return {"last_hidden_state": output, "hidden_states": None, "attentions": None}
else:
return (output, None, None)
return output
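The helper defined just below maps patch lengths to a per-token patch id, which the cross-attention masks above are built from. An illustrative stand-alone version of that mapping follows; it sketches the idea and is not necessarily the repo's exact implementation:

```python
import torch

def patch_ids_from_lengths(patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Start offset of each patch, then count how many patches begin at or before
    # each token position; that count minus one is the patch the token falls into.
    starts = torch.cumsum(patch_lengths, dim=-1) - patch_lengths      # (batch, num_patches)
    positions = torch.arange(seq_len, device=patch_lengths.device)    # (seq_len,)
    return (positions[None, :, None] >= starts[:, None, :]).sum(dim=-1) - 1

lengths = torch.tensor([[1, 3, 2]])        # three patches covering six tokens
print(patch_ids_from_lengths(lengths, 6))  # tensor([[0, 1, 1, 1, 2, 2]])
```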
def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
@@ -1167,10 +1181,35 @@ class BLTPatcher(BLTPreTrainedModel):
class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
config_class = BLTConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
def __init__(self, config):
super().__init__(config)
self.model = BLTModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.local_encoder.embed_tokens
def set_input_embeddings(self, value):
self.model.local_encoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
@@ -1183,35 +1222,59 @@ class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
**kwargs,
):
logits = self.model(input_ids)
"""
Args:
input_ids (torch.LongTensor): Input token ids.
attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments.
labels (torch.LongTensor, optional): Labels for language modeling loss.
Returns:
CausalLMOutputWithPast or tuple: Standard transformers output.
"""
# Route only input_ids to BLTModel (as tokens)
hidden_states = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
if isinstance(hidden_states, dict):
sequence_output = hidden_states["last_hidden_state"]
elif isinstance(hidden_states, tuple):
sequence_output = hidden_states[0]
else:
sequence_output = hidden_states
logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,)
if loss is not None:
output = (loss,) + output
return output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
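The loss is the standard causal shift: the logits at position i are scored against the token at position i + 1. A self-contained illustration with dummy logits:

```python
import torch

vocab_size = 260
logits = torch.randn(1, 4, vocab_size)       # model outputs for a 4-token sequence
labels = torch.tensor([[1, 104, 105, 2]])    # illustrative byte-level ids

shift_logits = logits[..., :-1, :].contiguous()   # positions 0..2
shift_labels = labels[..., 1:].contiguous()       # tokens    1..3
loss = torch.nn.CrossEntropyLoss()(
    shift_logits.view(-1, vocab_size), shift_labels.view(-1)
)
print(loss)
```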
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "past_key_values": past_key_values}
def get_input_embeddings(self):
return self.model.local_encoder.embed_tokens
def set_input_embeddings(self, value):
self.model.local_encoder.embed_tokens = value
def get_output_embeddings(self):
return self.model.local_decoder.lm_head
def set_output_embeddings(self, new_embeddings):
self.model.local_decoder.lm_head = new_embeddings
__all__ = [
"BLTPreTrainedModel",
"BLTModel",

File diff suppressed because it is too large.

View File

@@ -14,8 +14,6 @@
# limitations under the License.
"""Tokenization classes for BLT."""
import os
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
@@ -270,17 +268,4 @@ class BLTTokenizer(PreTrainedTokenizer):
"""Get vocab size like the original tokenizer."""
return self.vocab_size_unit_1 + self.offsetting_special_char
def __call__(self, text, **kwargs):
"""Override the default __call__ method to properly handle BOS/EOS tokens."""
# Use our custom encode method to ensure consistent behavior
if isinstance(text, str):
tokens = self.encode(text, add_bos=self.add_bos_token, add_eos=self.add_eos_token)
return {"input_ids": torch.tensor([tokens])}
elif isinstance(text, list):
tokens_list = [self.encode(t, add_bos=self.add_bos_token, add_eos=self.add_eos_token) for t in text]
return {"input_ids": torch.tensor(tokens_list)}
else:
# Fallback to parent implementation
return super().__call__(text, **kwargs)
__all__ = ["BLTTokenizer"]