Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit: 3772dc7646 (BLTForCausalLM)
Parent: bbe1bdb599
@@ -10,6 +10,7 @@ from safetensors.torch import load_file, save_file
from transformers.models.blt_wip.configuration_blt import BLTConfig
from transformers.models.blt_wip.modeling_blt import BLTModel
+from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
from transformers.utils import logging as transformers_logging
@ -156,6 +157,8 @@ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str
|
||||
"global_config": global_config,
|
||||
}
|
||||
|
||||
main_config_dict["tie_word_embeddings"] = False
|
||||
|
||||
logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
|
||||
return main_config_dict
|
||||
|
||||
@ -203,8 +206,6 @@ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, tor
|
||||
elif "state_dict" in entropy_weights:
|
||||
entropy_weights = entropy_weights["state_dict"]
|
||||
|
||||
logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors")
|
||||
|
||||
unified_weights = main_weights.copy()
|
||||
|
||||
for key, tensor in entropy_weights.items():
|
||||
@@ -213,6 +214,22 @@
    unified_weights = apply_weight_mapping(unified_weights)

+    decoder_lm_head_key = "local_decoder.lm_head.weight"
+    top_lm_head_key = "lm_head.weight"
+    unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
+    del unified_weights[decoder_lm_head_key]
+
+    prefixed_weights = {}
+    for key, tensor in unified_weights.items():
+        if key == top_lm_head_key:
+            prefixed_weights[key] = tensor
+        elif not key.startswith("model."):
+            prefixed_weights[f"model.{key}"] = tensor
+        else:
+            prefixed_weights[key] = tensor
+
+    unified_weights = prefixed_weights
+
    return unified_weights
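A minimal standalone sketch of what this remapping does to checkpoint keys (the key names other than the two lm_head keys, and all shapes, are illustrative, not taken from a real BLT checkpoint):

    import torch

    # Illustrative stand-in for the merged checkpoint.
    unified_weights = {
        "local_decoder.lm_head.weight": torch.zeros(8, 4),
        "local_encoder.embed_tokens.weight": torch.zeros(8, 4),
        "global_transformer.layers.0.weight": torch.zeros(4, 4),
    }

    # Promote the decoder's LM head to a top-level lm_head ...
    unified_weights["lm_head.weight"] = unified_weights.pop("local_decoder.lm_head.weight")

    # ... and prefix every other key with "model." so it loads into BLTForCausalLM.model.
    unified_weights = {
        k if k == "lm_head.weight" or k.startswith("model.") else f"model.{k}": v
        for k, v in unified_weights.items()
    }

    print(sorted(unified_weights))
    # ['lm_head.weight', 'model.global_transformer.layers.0.weight', 'model.local_encoder.embed_tokens.weight']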
@@ -233,8 +250,6 @@ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
    with open(tokenizer_path, "w") as f:
        json.dump(tokenizer_config, f, indent=2)

    logger.info(f"Tokenizer config saved to {tokenizer_path}")


def push_to_hub(
    local_dir: str,
@ -344,7 +359,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
type=str,
|
||||
default="itazap/blt-1b-converted",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_private",
|
||||
|
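With --push_to_hub now defaulting to None, pushing the converted checkpoint becomes opt-in. A self-contained sketch of the changed behaviour (argparse only; the conversion script's other flags are omitted here):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--push_to_hub", type=str, default=None)

    assert parser.parse_args([]).push_to_hub is None  # conversion stays local by default
    assert parser.parse_args(["--push_to_hub", "itazap/blt-1b-converted"]).push_to_hub == "itazap/blt-1b-converted"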
@@ -23,19 +23,8 @@ torch.cuda.empty_cache()
def main(prompt: str = "my name is", model_name: str = "blt-1b"):
    device = "cuda"

-    blt_model = BLTModel.from_pretrained("itazap/blt-1b-converted")
-
-    causal_lm = BLTForCausalLM(blt_model.config)
-    causal_lm.model.load_state_dict(blt_model.state_dict(), strict=False)
-    causal_lm.lm_head.weight = blt_model.local_decoder.lm_head.weight
-    causal_lm.save_pretrained("./blt-1b-causallm")
-
-    # TRUE causal_lm.lm_head.weight == blt_model.local_decoder.lm_head.weight
-
-    model = BLTForCausalLM.from_pretrained("./blt-1b-causallm").to(device)
+    model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device)

-    # FALSE model.lm_head.weight != blt_model.local_decoder.lm_head.weight
-
    tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True)

    input_ids = torch.tensor([tokenizer.encode(prompt, add_eos=False)]).to(device)
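The removed TRUE/FALSE comments record the motivation for this commit: the top-level lm_head did not survive reloading the converted checkpoint. A minimal round-trip check in the spirit of those comments (a sketch; the import path mirrors the conversion script's imports, and the local directory name is illustrative):

    import torch
    from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM

    model = BLTForCausalLM.from_pretrained("itazap/blt-1b")
    model.save_pretrained("./blt-1b-roundtrip")
    reloaded = BLTForCausalLM.from_pretrained("./blt-1b-roundtrip")

    # With the LM head stored at the top level and tie_word_embeddings=False,
    # the head weight should survive the save/load round trip unchanged.
    assert torch.equal(model.lm_head.weight, reloaded.lm_head.weight)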
@@ -321,23 +321,15 @@ class BLTConfig(PretrainedConfig):
        encoder_config=None,
        decoder_config=None,
        global_config=None,
        # Generation configuration
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=-1,
        tie_word_embeddings=False,
        **kwargs,
    ):

        # Basic model configuration
        self.tie_word_embeddings = tie_word_embeddings
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings

        # Generation configuration
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.return_dict = True

        # Patching configuration
        self.patch_in_forward = patch_in_forward
        self.patch_size = patch_size
@@ -387,7 +379,7 @@ class BLTConfig(PretrainedConfig):
        elif isinstance(global_config, BLTGlobalTransformerConfig):
            self.global_config = global_config

-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = [
    "BLTConfig",
@@ -666,14 +666,11 @@ class BLTLocalDecoder(nn.Module):
                BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
            )

-        self.lm_head = nn.Linear(
-            config.hidden_size,
-            config.vocab_size,
-            bias=False,
-        )
-        z = 5
+        z = 5+1
+        # self.lm_head = nn.Linear(
+        #     config.hidden_size,
+        #     config.vocab_size,
+        #     bias=False,
+        # )

    def forward(
@@ -718,7 +715,7 @@
            hidden_states = layer_outputs[0]

        logits = self.norm(hidden_states)
-        logits = self.lm_head(logits)
+        # logits = self.lm_head(logits)
        return logits, cache
@@ -914,24 +911,24 @@ class BLTPreTrainedModel(PreTrainedModel):
            emb_std = module.config.hidden_size ** (-0.5)
            module.embed_tokens._custom_std = emb_std
-            module.lm_head._custom_std = emb_std

+        elif isinstance(module, BLTForCausalLM):
+            if module.lm_head is not None:
+                module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5)


class BLTModel(BLTPreTrainedModel):
    def __init__(self, config: BLTConfig):
        super().__init__(config)

        self.config = config

        self.local_encoder = BLTLocalEncoder(config.encoder_config)
        self.global_transformer = BLTGlobalTransformer(config.global_config)
        self.local_decoder = BLTLocalDecoder(config.decoder_config)

        self.encoder_hash_tok_embedding = init_hash_embeddings(
            config,
            local_encoder_dim=config.encoder_config.hidden_size,
            encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
        )

        if self.config.patch_in_forward:
            self.patcher = BLTPatcher(config.patcher_config)
            self.patcher.eval()
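The _custom_std attribute attached here is presumably read back by the weight-initialization logic elsewhere in the file (not shown in this diff). A minimal sketch of that pattern, with the consuming side written as an assumption:

    import torch.nn as nn

    hidden_size = 512
    lm_head = nn.Linear(hidden_size, 260, bias=False)  # 260 is an illustrative byte-level vocab size
    lm_head._custom_std = hidden_size ** (-0.5)

    # Assumed consumer: an _init_weights pass that prefers the per-module std over a global default.
    std = getattr(lm_head, "_custom_std", 0.02)
    nn.init.normal_(lm_head.weight, mean=0.0, std=std)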
@@ -940,9 +937,30 @@ class BLTModel(BLTPreTrainedModel):
        else:
            self.patcher = None

-    def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        patch_lengths: Optional[torch.Tensor] = None,
+        attention_mask=None,
+        position_ids=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        """
+        Args:
+            tokens (torch.Tensor): Input token ids.
+            patch_lengths (Optional[torch.Tensor]): Patch lengths for patching.
+            attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility.
+        Returns:
+            torch.Tensor: Final hidden states (as before).
+        """
        batch_size, sequence_length = tokens.shape

        # Handle patching
        if patch_lengths is None:
            if self.config.patching_mode == PatchingModeEnum.entropy:
@@ -955,24 +973,20 @@
                    device=tokens.device,
                )
            else:
                # Default to byte-level patching
                patch_lengths = process_patch_lengths(
                    torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
                    self.config.max_patch_length
                )

        patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
        cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
            patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
        )

        encoder_embeds = compute_hash_embeddings(
            tokens, self.local_encoder, self.encoder_hash_tok_embedding,
            self.config.encoder_hash_byte_group_nb_functions,
            self.config.encoder_hash_byte_group_size,
            self.config.encoder_hash_byte_group_vocab,
        )

        encoder_hidden_states, encoder_cross_states = self.local_encoder(
            input_ids=tokens,
            input_embeds=encoder_embeds,
@@ -982,18 +996,14 @@
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)

        global_hidden_states, _ = self.global_transformer(
            input_embeds=global_hidden_states,
        )

        decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
        cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
            decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
        )

        output, _ = self.local_decoder(
            tokens=tokens,
            embeds=encoder_hidden_states,
@@ -1002,7 +1012,11 @@
            cross_mask=cross_attn_mask_dec,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
        )

+        if output_hidden_states or output_attentions:
+            if return_dict:
+                return {"last_hidden_state": output, "hidden_states": None, "attentions": None}
+            else:
+                return (output, None, None)
        return output

    def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
@@ -1167,10 +1181,35 @@ class BLTPatcher(BLTPreTrainedModel):


class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
    config_class = BLTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]

    def __init__(self, config):
        super().__init__(config)
        self.model = BLTModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.local_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.local_encoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
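Because BLTForCausalLM mixes in GenerationMixin and owns a top-level lm_head, the converted checkpoint can be driven through the standard generate() API. A sketch (repo id taken from the demo script above; the tokenizer import path and the generation arguments are assumptions, not shown in this diff):

    import torch
    from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
    from transformers.models.blt_wip.tokenization_blt import BLTTokenizer  # assumed path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BLTForCausalLM.from_pretrained("itazap/blt-1b").to(device)
    tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True)

    input_ids = torch.tensor([tokenizer.encode("my name is", add_eos=False)]).to(device)
    with torch.no_grad():
        generated = model.generate(input_ids, max_new_tokens=32, do_sample=False)
    print(generated[0].tolist())  # raw byte-level token ids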
@@ -1183,35 +1222,59 @@ class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        **kwargs,
    ):
-        logits = self.model(input_ids)

+        """
+        Args:
+            input_ids (torch.LongTensor): Input token ids.
+            attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments.
+            labels (torch.LongTensor, optional): Labels for language modeling loss.
+        Returns:
+            CausalLMOutputWithPast or tuple: Standard transformers output.
+        """
+        # Route only input_ids to BLTModel (as tokens)
+        hidden_states = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        if isinstance(hidden_states, dict):
+            sequence_output = hidden_states["last_hidden_state"]
+        elif isinstance(hidden_states, tuple):
+            sequence_output = hidden_states[0]
+        else:
+            sequence_output = hidden_states
+        logits = self.lm_head(sequence_output)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (logits,)
            if loss is not None:
                output = (loss,) + output
            return output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values}

-    def get_input_embeddings(self):
-        return self.model.local_encoder.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.local_encoder.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.model.local_decoder.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.model.local_decoder.lm_head = new_embeddings


__all__ = [
    "BLTPreTrainedModel",
    "BLTModel",
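The label handling above is the standard shifted next-token loss; a small self-contained check of what the shift does (toy shapes, random tensors):

    import torch

    batch, seq, vocab = 2, 5, 11
    logits = torch.randn(batch, seq, vocab)
    labels = torch.randint(0, vocab, (batch, seq))

    # Position t of the shifted logits is scored against label t+1: "tokens < n predict n".
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = torch.nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
    print(loss.shape, float(loss))  # scalar loss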
File diff suppressed because it is too large.
@@ -14,8 +14,6 @@
# limitations under the License.
"""Tokenization classes for BLT."""

-import os
-import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
@@ -270,17 +268,4 @@ class BLTTokenizer(PreTrainedTokenizer):
        """Get vocab size like the original tokenizer."""
        return self.vocab_size_unit_1 + self.offsetting_special_char

-    def __call__(self, text, **kwargs):
-        """Override the default __call__ method to properly handle BOS/EOS tokens."""
-        # Use our custom encode method to ensure consistent behavior
-        if isinstance(text, str):
-            tokens = self.encode(text, add_bos=self.add_bos_token, add_eos=self.add_eos_token)
-            return {"input_ids": torch.tensor([tokens])}
-        elif isinstance(text, list):
-            tokens_list = [self.encode(t, add_bos=self.add_bos_token, add_eos=self.add_eos_token) for t in text]
-            return {"input_ids": torch.tensor(tokens_list)}
-        else:
-            # Fallback to parent implementation
-            return super().__call__(text, **kwargs)
-

__all__ = ["BLTTokenizer"]
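With the custom __call__ override gone, callers either use encode directly (as the demo script in this commit does) or fall back to the stock PreTrainedTokenizer.__call__. A minimal sketch of the encode path (tokenizer import path assumed, as above):

    import torch
    from transformers.models.blt_wip.tokenization_blt import BLTTokenizer  # assumed path

    tokenizer = BLTTokenizer(add_bos_token=True, add_eos_token=True)
    input_ids = torch.tensor([tokenizer.encode("my name is", add_eos=False)])
    print(input_ids.shape)  # (1, number_of_byte_tokens_plus_bos)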