mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
adding files after add-new-model-like
This commit is contained in:
parent
3772dc7646
commit
1a50b29693
@ -395,6 +395,8 @@
|
||||
title: Blenderbot Small
|
||||
- local: model_doc/bloom
|
||||
title: BLOOM
|
||||
- local: model_doc/blt
|
||||
title: BLT
|
||||
- local: model_doc/bort
|
||||
title: BORT
|
||||
- local: model_doc/byt5
|
||||
|
102
docs/source/en/model_doc/blt.md
Normal file
102
docs/source/en/model_doc/blt.md
Normal file
@ -0,0 +1,102 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# BLT
|
||||
|
||||
# BLT
|
||||
|
||||
## Overview
|
||||
|
||||
The BLT model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*<INSERT PAPER ABSTRACT HERE>*
|
||||
|
||||
Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
|
||||
## BLTConfig
|
||||
|
||||
[[autodoc]] BLTConfig
|
||||
|
||||
## BLTTokenizer
|
||||
|
||||
[[autodoc]] BLTTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
|
||||
## BLTTokenizerFast
|
||||
|
||||
[[autodoc]] BLTTokenizerFast
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- update_post_processor
|
||||
- save_vocabulary
|
||||
|
||||
## BLTModel
|
||||
|
||||
[[autodoc]] BLTModel
|
||||
- forward
|
||||
|
||||
## BLTForCausalLM
|
||||
|
||||
[[autodoc]] BLTForCausalLM
|
||||
- forward
|
||||
|
||||
## BLTForSequenceClassification
|
||||
|
||||
[[autodoc]] BLTForSequenceClassification
|
||||
- forward
|
||||
|
||||
## BLTForQuestionAnswering
|
||||
|
||||
[[autodoc]] BLTForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## BLTForTokenClassification
|
||||
|
||||
[[autodoc]] BLTForTokenClassification
|
||||
- forward
|
||||
|
||||
## FlaxBLTModel
|
||||
|
||||
[[autodoc]] FlaxBLTModel
|
||||
- __call__
|
||||
|
||||
## FlaxBLTForCausalLM
|
||||
|
||||
[[autodoc]] FlaxBLTForCausalLM
|
||||
- __call__
|
@ -200,6 +200,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("lightglue", "LightGlueConfig"),
|
||||
("lilt", "LiltConfig"),
|
||||
("llama", "LlamaConfig"),
|
||||
("blt", "BLTConfig"),
|
||||
("llama4", "Llama4Config"),
|
||||
("llama4_text", "Llama4TextConfig"),
|
||||
("llava", "LlavaConfig"),
|
||||
@ -586,6 +587,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("lightglue", "LightGlue"),
|
||||
("lilt", "LiLT"),
|
||||
("llama", "LLaMA"),
|
||||
("blt", "BLT"),
|
||||
("llama2", "Llama2"),
|
||||
("llama3", "Llama3"),
|
||||
("llama4", "Llama4"),
|
||||
|
@ -189,6 +189,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("lightglue", "LightGlueForKeypointMatching"),
|
||||
("lilt", "LiltModel"),
|
||||
("llama", "LlamaModel"),
|
||||
("blt", "BLTModel"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("llama4_text", "Llama4TextModel"),
|
||||
("llava", "LlavaModel"),
|
||||
@ -609,6 +610,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("jamba", "JambaForCausalLM"),
|
||||
("jetmoe", "JetMoeForCausalLM"),
|
||||
("llama", "LlamaForCausalLM"),
|
||||
("blt", "BLTForCausalLM"),
|
||||
("llama4", "Llama4ForCausalLM"),
|
||||
("llama4_text", "Llama4ForCausalLM"),
|
||||
("mamba", "MambaForCausalLM"),
|
||||
|
0
src/transformers/models/blt/__init__.py
Normal file
0
src/transformers/models/blt/__init__.py
Normal file
397
src/transformers/models/blt/convert_blt_weights_to_hf.py
Normal file
397
src/transformers/models/blt/convert_blt_weights_to_hf.py
Normal file
@ -0,0 +1,397 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_folder
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from transformers.models.blt_wip.configuration_blt import BLTConfig
|
||||
from transformers.models.blt_wip.modeling_blt import BLTModel
|
||||
from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
logger = transformers_logging.get_logger(__name__)
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
|
||||
logger.info("Merging configurations")
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
main_config = json.load(f)
|
||||
|
||||
with open(entropy_params_path, "r") as f:
|
||||
entropy_data = json.load(f)
|
||||
|
||||
entropy_model_params = entropy_data.get("entropy_model", {})
|
||||
patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
|
||||
|
||||
unified_config = main_config.copy()["args"]
|
||||
|
||||
for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
|
||||
if key in unified_config and not isinstance(unified_config[key], int):
|
||||
unified_config[key] = int(unified_config[key])
|
||||
|
||||
patch_size = patcher_args.get("patch_size", 8)
|
||||
if isinstance(patch_size, float):
|
||||
patch_size = int(patch_size)
|
||||
|
||||
# Create patcher config
|
||||
patcher_hidden_size = int(entropy_model_params.get("dim", 512))
|
||||
patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
|
||||
patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of)
|
||||
|
||||
patcher_config = {
|
||||
"vocab_size": int(entropy_model_params.get("vocab_size", 256)),
|
||||
"hidden_size": patcher_hidden_size,
|
||||
"num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
|
||||
"num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
|
||||
"num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
|
||||
if entropy_model_params.get("n_kv_heads") is not None
|
||||
else None,
|
||||
"max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
|
||||
"norm_eps": entropy_model_params.get("norm_eps", 1e-5),
|
||||
"dropout": entropy_model_params.get("dropout", 0.0),
|
||||
"rope_theta": entropy_model_params.get("rope_theta", 10000.0),
|
||||
"attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
|
||||
"attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
|
||||
"intermediate_size": patcher_intermediate_size,
|
||||
}
|
||||
|
||||
# Create encoder config
|
||||
encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
|
||||
encoder_multiple_of = unified_config.get("multiple_of", 256)
|
||||
encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of)
|
||||
|
||||
encoder_config = {
|
||||
"vocab_size": unified_config.get("vocab_size", 256),
|
||||
"cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
|
||||
"cross_attn_k": unified_config.get("cross_attn_k", 2),
|
||||
"hidden_size_global": unified_config.get("hidden_size_global", 2048),
|
||||
"pm_size": unified_config.get("pm_size", 0),
|
||||
"hidden_size": encoder_hidden_size,
|
||||
"num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
|
||||
"num_key_value_heads": unified_config.get("n_kv_heads"),
|
||||
"num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
|
||||
"norm_eps": unified_config.get("norm_eps", 1e-5),
|
||||
"dropout": unified_config.get("dropout", 0.0),
|
||||
"max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
|
||||
"rope_theta": unified_config.get("rope_theta", 10000.0),
|
||||
"rope_scaling": {"rope_type": "default"},
|
||||
"hidden_act": unified_config.get("hidden_act", "silu"),
|
||||
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
|
||||
"intermediate_size": encoder_intermediate_size,
|
||||
}
|
||||
|
||||
# Create decoder config
|
||||
decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
|
||||
decoder_multiple_of = unified_config.get("multiple_of", 256)
|
||||
decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of)
|
||||
|
||||
decoder_config = {
|
||||
"vocab_size": unified_config.get("vocab_size", 256),
|
||||
"cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
|
||||
"cross_attn_k": unified_config.get("cross_attn_k", 2),
|
||||
"hidden_size_global": unified_config.get("hidden_size_global", 2048),
|
||||
"hidden_size": decoder_hidden_size,
|
||||
"num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
|
||||
"num_key_value_heads": unified_config.get("n_kv_heads"),
|
||||
"num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
|
||||
"norm_eps": unified_config.get("norm_eps", 1e-5),
|
||||
"dropout": unified_config.get("dropout", 0.0),
|
||||
"max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
|
||||
"rope_theta": unified_config.get("rope_theta", 10000.0),
|
||||
"rope_scaling": {"rope_type": "default"},
|
||||
"hidden_act": unified_config.get("hidden_act", "silu"),
|
||||
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
|
||||
"intermediate_size": decoder_intermediate_size,
|
||||
}
|
||||
|
||||
# Create global transformer config
|
||||
global_hidden_size = unified_config.get("dim_global", 2048)
|
||||
global_multiple_of = unified_config.get("multiple_of", 256)
|
||||
global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of)
|
||||
|
||||
global_config = {
|
||||
"hidden_size": global_hidden_size,
|
||||
"num_attention_heads": unified_config.get("n_heads_global", 16),
|
||||
"num_key_value_heads": unified_config.get("n_kv_heads_global"),
|
||||
"num_hidden_layers": unified_config.get("n_layers_global", 25),
|
||||
"norm_eps": unified_config.get("norm_eps", 1e-5),
|
||||
"dropout": unified_config.get("dropout", 0.0),
|
||||
"max_position_embeddings": unified_config.get("max_seqlen", 1024),
|
||||
"rope_theta": unified_config.get("rope_theta", 10000.0),
|
||||
"rope_scaling": {"rope_type": "default"},
|
||||
"hidden_act": unified_config.get("hidden_act", "silu"),
|
||||
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
|
||||
"intermediate_size": global_intermediate_size,
|
||||
}
|
||||
|
||||
# Create main config with sub-configs
|
||||
main_config_dict = {
|
||||
"model_type": "blt",
|
||||
"vocab_size": unified_config.get("vocab_size", 256),
|
||||
"max_position_embeddings": unified_config.get("max_seqlen", 1024),
|
||||
"patch_in_forward": True,
|
||||
"realtime_patching": True,
|
||||
"patching_mode": "entropy",
|
||||
"patch_size": patch_size,
|
||||
"patching_threshold": patcher_args.get("threshold", 0.5),
|
||||
"patching_threshold_add": patcher_args.get("threshold_add", 0.0),
|
||||
"max_patch_length": patcher_args.get("max_patch_length"),
|
||||
"patching_batch_size": patcher_args.get("patching_batch_size", 1),
|
||||
"patching_device": patcher_args.get("patching_device", "cuda"),
|
||||
"monotonicity": patcher_args.get("monotonicity", False),
|
||||
"cross_attn_k": unified_config.get("cross_attn_k", 2),
|
||||
"encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
|
||||
"encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
|
||||
"encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
|
||||
"pm_size": unified_config.get("pm_size", 0),
|
||||
"patcher_config": patcher_config,
|
||||
"encoder_config": encoder_config,
|
||||
"decoder_config": decoder_config,
|
||||
"global_config": global_config,
|
||||
}
|
||||
|
||||
main_config_dict["tie_word_embeddings"] = False
|
||||
|
||||
logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
|
||||
return main_config_dict
|
||||
|
||||
|
||||
def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
component_mappings = {
|
||||
".attention.": ".self_attn.",
|
||||
".feed_forward.": ".mlp.",
|
||||
".attention_norm.": ".input_layernorm.",
|
||||
".ffn_norm.": ".post_attention_layernorm.",
|
||||
".tok_embeddings.": ".embed_tokens.",
|
||||
".cross_attn_norm_q.": ".q_norm.",
|
||||
".cross_attn_norm_kv.": ".k_norm.",
|
||||
".w1.": ".gate_proj.",
|
||||
".w2.": ".down_proj.",
|
||||
".w3.": ".up_proj.",
|
||||
".wq.": ".q_proj.",
|
||||
".wk.": ".k_proj.",
|
||||
".wv.": ".v_proj.",
|
||||
".wo.": ".o_proj.",
|
||||
".output.": ".lm_head.",
|
||||
}
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
for old_key, tensor in state_dict.items():
|
||||
new_key = old_key
|
||||
|
||||
for old_pattern, new_pattern in component_mappings.items():
|
||||
if old_pattern in new_key:
|
||||
new_key = new_key.replace(old_pattern, new_pattern)
|
||||
|
||||
new_state_dict[new_key] = tensor
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
|
||||
main_weights = load_file(weights_path)
|
||||
|
||||
entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
|
||||
|
||||
if "model" in entropy_weights:
|
||||
entropy_weights = entropy_weights["model"]
|
||||
elif "state_dict" in entropy_weights:
|
||||
entropy_weights = entropy_weights["state_dict"]
|
||||
|
||||
unified_weights = main_weights.copy()
|
||||
|
||||
for key, tensor in entropy_weights.items():
|
||||
patcher_key = f"patcher.{key}"
|
||||
unified_weights[patcher_key] = tensor
|
||||
|
||||
unified_weights = apply_weight_mapping(unified_weights)
|
||||
|
||||
decoder_lm_head_key = "local_decoder.lm_head.weight"
|
||||
top_lm_head_key = "lm_head.weight"
|
||||
unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
|
||||
del unified_weights[decoder_lm_head_key]
|
||||
|
||||
prefixed_weights = {}
|
||||
for key, tensor in unified_weights.items():
|
||||
if key == top_lm_head_key:
|
||||
prefixed_weights[key] = tensor
|
||||
elif not key.startswith("model."):
|
||||
prefixed_weights[f"model.{key}"] = tensor
|
||||
else:
|
||||
prefixed_weights[key] = tensor
|
||||
|
||||
unified_weights = prefixed_weights
|
||||
|
||||
return unified_weights
|
||||
|
||||
|
||||
def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
|
||||
tokenizer_config = {
|
||||
"tokenizer_class": "BltTokenizer",
|
||||
"vocab_size": config.get("vocab_size", 256),
|
||||
"model_max_length": config.get("max_seqlen", 1024),
|
||||
"add_bos_token": True,
|
||||
"add_eos_token": True,
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
|
||||
with open(tokenizer_path, "w") as f:
|
||||
json.dump(tokenizer_config, f, indent=2)
|
||||
|
||||
|
||||
def push_to_hub(
|
||||
local_dir: str,
|
||||
repo_id: str,
|
||||
commit_message: str = "Upload converted BLT model",
|
||||
private: bool = False,
|
||||
token: Optional[str] = None,
|
||||
) -> None:
|
||||
try:
|
||||
upload_folder(
|
||||
folder_path=local_dir,
|
||||
repo_id=repo_id,
|
||||
commit_message=commit_message,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
)
|
||||
logger.info(f"Successfully pushed model to {repo_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to push model to Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_hf_blt_to_unified(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
config_name: str = "config.json",
|
||||
weights_name: str = "model.bin",
|
||||
cache_dir: Optional[str] = None,
|
||||
push_to_hub_repo: Optional[str] = None,
|
||||
hub_private: bool = False,
|
||||
hub_token: Optional[str] = None,
|
||||
) -> None:
|
||||
# Download model files
|
||||
config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
|
||||
weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
|
||||
entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir)
|
||||
entropy_weights_path = hf_hub_download(
|
||||
repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
|
||||
)
|
||||
|
||||
unified_config = merge_configurations(config_path, entropy_params_path)
|
||||
unified_weights = merge_weights(weights_path, entropy_weights_path)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
config_path = os.path.join(output_dir, config_name)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(unified_config, f, indent=2)
|
||||
|
||||
if weights_name.endswith(".bin"):
|
||||
weights_name = weights_name.replace(".bin", ".safetensors")
|
||||
|
||||
weights_path = os.path.join(output_dir, weights_name)
|
||||
save_file(unified_weights, weights_path)
|
||||
|
||||
create_tokenizer_config(output_dir, unified_config)
|
||||
|
||||
logger.info(f"Conversion completed, model saved to: {output_dir}")
|
||||
|
||||
if push_to_hub_repo:
|
||||
push_to_hub(
|
||||
local_dir=output_dir,
|
||||
repo_id=push_to_hub_repo,
|
||||
commit_message="Upload BLT model converted",
|
||||
private=hub_private,
|
||||
token=hub_token,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert BLT models from HuggingFace Hub format to unified format",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="facebook/blt-1b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="./blt_converted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default="config.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weights_name",
|
||||
type=str,
|
||||
default="model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_private",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_token",
|
||||
type=str,
|
||||
default="hf_token",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
transformers_logging.set_verbosity_debug()
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
try:
|
||||
convert_hf_blt_to_unified(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
config_name=args.config_name,
|
||||
weights_name=args.weights_name,
|
||||
cache_dir=args.cache_dir,
|
||||
push_to_hub_repo=args.push_to_hub,
|
||||
hub_private=args.hub_private,
|
||||
hub_token=args.hub_token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Conversion failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
tests/models/blt/__init__.py
Normal file
0
tests/models/blt/__init__.py
Normal file
930
tests/models/blt/test_modeling_blt.py
Normal file
930
tests/models/blt/test_modeling_blt.py
Normal file
@ -0,0 +1,930 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch BLT model."""
|
||||
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, BLTConfig, StaticCache, is_torch_available, set_seed
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BLTForCausalLM,
|
||||
BLTForQuestionAnswering,
|
||||
BLTForSequenceClassification,
|
||||
BLTForTokenClassification,
|
||||
BLTModel,
|
||||
BLTTokenizer,
|
||||
)
|
||||
from transformers.models.blt.modeling_blt import BLTRotaryEmbedding
|
||||
|
||||
|
||||
class BLTModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
return BLTConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = BLTModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class BLTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
BLTModel,
|
||||
BLTForCausalLM,
|
||||
BLTForSequenceClassification,
|
||||
BLTForQuestionAnswering,
|
||||
BLTForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = BLTForCausalLM if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BLTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BLTConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_blt_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = BLTForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_blt_sequence_classification_model_for_single_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "single_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = BLTForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_blt_sequence_classification_model_for_multi_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "multi_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor(
|
||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
).to(torch.float)
|
||||
model = BLTForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_blt_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = BLTForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = BLTModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = BLTModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
def test_model_rope_scaling(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
scaling_factor = 10
|
||||
short_input_length = 10
|
||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(
|
||||
1, dtype=torch.float32, device=torch_device
|
||||
) # used exclusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = BLTRotaryEmbedding(config=config).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||
linear_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
# with scaling_factor (or that `inv_freq` decreases)
|
||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||
ntk_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||
|
||||
# Sanity check Yarn RoPE scaling
|
||||
# Scaling should be over the entire input
|
||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||
yarn_scaling_rope = BLTRotaryEmbedding(config=config).to(torch_device)
|
||||
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||
|
||||
def test_model_loading_old_rope_configs(self):
|
||||
def _reinitialize_config(base_config, new_kwargs):
|
||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
||||
# steps.
|
||||
base_config_dict = base_config.to_dict()
|
||||
new_config = BLTConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
|
||||
return new_config
|
||||
|
||||
# from untouched config -> ✅
|
||||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_model = BLTForCausalLM(base_config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the expected rope configuration -> ✅
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
|
||||
original_model = BLTForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
|
||||
original_model = BLTForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
|
||||
)
|
||||
self.assertTrue(config.rope_scaling["type"] == "linear")
|
||||
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
|
||||
original_model = BLTForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
|
||||
original_model = BLTForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("factor field", logs.output[0])
|
||||
|
||||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
|
||||
)
|
||||
original_model = BLTForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Unrecognized keys", logs.output[0])
|
||||
|
||||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
|
||||
with self.assertRaises(KeyError):
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class BLTIntegrationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
|
||||
# some memory allocated in the cache, which means some object is not being released properly. This causes some
|
||||
# unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
|
||||
# Investigate the root cause.
|
||||
cleanup(torch_device, gc_collect=False)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_blt_3_1_hard(self):
|
||||
"""
|
||||
An integration test for blt 3.1. It tests against a long output to ensure the subtle numerical differences
|
||||
from blt 3.1.'s RoPE can be detected
|
||||
"""
|
||||
# diff on `EXPECTED_TEXT`:
|
||||
# 2024-08-26: updating from torch 2.3.1 to 2.4.0 slightly changes the results.
|
||||
EXPECTED_TEXT = (
|
||||
"Tell me about the french revolution. The french revolution was a period of radical political and social "
|
||||
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
|
||||
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
|
||||
"First French Republic.\nThe revolution began in 1789 with the Estates-General, a representative "
|
||||
"assembly that had not met since 1614. The Third Estate, which represented the common people, "
|
||||
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
|
||||
"the beginning of the end of the absolute monarchy and the rise of the middle class.\n"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-blt/Meta-BLT-3.1-8B-Instruct")
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
"meta-blt/Meta-BLT-3.1-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
|
||||
)
|
||||
input_text = ["Tell me about the french revolution."]
|
||||
model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
|
||||
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(generated_text, EXPECTED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_7b_logits_bf16(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
"meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
# Expected mean on dim = -1
|
||||
|
||||
# fmt: off
|
||||
expected_means = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]]),
|
||||
("cuda", 7): torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
|
||||
("cuda", 8): torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
|
||||
})
|
||||
|
||||
expected_mean = expected_means.get_expectation()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_mean.to(torch_device),
|
||||
out.logits.float().mean(-1),
|
||||
atol=1e-2,
|
||||
rtol=1e-2
|
||||
)
|
||||
)
|
||||
|
||||
# slicing logits[0, 0, 0:15]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]]),
|
||||
("cuda", 7): torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
|
||||
("cuda", 8): torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
|
||||
})
|
||||
# fmt: on
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_slice.to(torch_device),
|
||||
out.logits[0, 0, :15].float(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
"meta-blt/BLT-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
|
||||
# fmt: off
|
||||
# Expected mean on dim = -1
|
||||
expected_means = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
|
||||
("cuda", 7): torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
|
||||
("cuda", 8): torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]]),
|
||||
})
|
||||
|
||||
expected_mean = expected_means.get_expectation()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_mean.to(torch_device),
|
||||
out.logits.float().mean(-1),
|
||||
atol=1e-2,
|
||||
rtol=1e-2
|
||||
)
|
||||
)
|
||||
|
||||
# slicing logits[0, 0, 0:15]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328]),
|
||||
("cuda", 7): torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
|
||||
("cuda", 8): torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
|
||||
})
|
||||
# fmt: on
|
||||
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_slice.to(torch_device),
|
||||
out.logits[0, 0, :15].float(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
|
||||
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
|
||||
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
|
||||
"understanding of space and time."
|
||||
)
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-chat-hf")
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
"meta-blt/BLT-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
|
||||
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||
"theory of relativ",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = BLTTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
"meta-blt/BLT-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_export_static_cache(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.4.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
|
||||
from transformers.integrations.executorch import (
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
|
||||
blt_models = {
|
||||
"meta-blt/BLT-3.2-1B": [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all "
|
||||
"observers, regardless of their location, and 2) the laws of physics are the same for all observers"
|
||||
],
|
||||
}
|
||||
|
||||
for blt_model_ckp, EXPECTED_TEXT_COMPLETION in blt_models.items():
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(blt_model_ckp, pad_token="</s>", padding_side="right")
|
||||
max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
|
||||
"input_ids"
|
||||
].shape[-1]
|
||||
|
||||
# Load model
|
||||
device = "cpu"
|
||||
dtype = torch.bfloat16
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa"
|
||||
batch_size = 1
|
||||
model = BLTForCausalLM.from_pretrained(
|
||||
blt_model_ckp,
|
||||
device_map=device,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation=attn_implementation,
|
||||
generation_config=GenerationConfig(
|
||||
use_cache=True,
|
||||
cache_implementation=cache_implementation,
|
||||
max_length=max_generation_length,
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_generation_length,
|
||||
"device": device,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
prompts = ["Simply put, the theory of relativity states that "]
|
||||
prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
prompt_token_ids = prompt_tokens["input_ids"]
|
||||
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
|
||||
|
||||
# Static Cache + export
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
|
||||
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
|
||||
)
|
||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def setUp(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
model_name = "TinyBLT/TinyBLT-1.1B-Chat-v1.0"
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = BLTTokenizer.from_pretrained(model_name)
|
||||
self.model = BLTForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
||||
batch_separate = [template.format(x) for x in items] # 3 separate lines
|
||||
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
|
||||
|
||||
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
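# 4D attention mask for a single packed sequence of 12 tokens: the first 3 tokens are the
# shared prefix, and each following 3-token item attends causally to that prefix and to
# itself, but never to the other items.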
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
|
||||
|
||||
# build custom position ids from the custom mask: each token's position is the number of tokens it can attend to, minus one
|
||||
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
|
||||
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
|
||||
|
||||
# invert the mask into an additive attention mask: 0 where attention is allowed, dtype minimum where it is blocked
|
||||
min_dtype = torch.finfo(self.model_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_stacked_causal_mask(self):
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
|
||||
).logits
|
||||
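# select the logits at the last token of each packed item, i.e. the positions whose position id is maximal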
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# 2 forward runs with custom 4D masks
|
||||
part_a = 3 # split point
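# the first forward pass runs only the first part_a tokens (the shared prefix) to fill the
# KV cache; the second pass then reuses that cache for the remaining tokens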
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
|
||||
past_key_values_a = outs_1a["past_key_values"]
|
||||
|
||||
# Case 1: we pass a 4D attention mask whose query dimension covers only the new tokens (i.e. shaped [..., seq_len, full_len])
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
past_key_values=past_key_values_a,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
def test_stacked_causal_mask_static_cache(self):
|
||||
"""same as above but with StaticCache"""
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
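# the additive mask is padded on the key axis with the dtype minimum so that the unused
# static-cache slots stay masked out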
padded_attention_mask = torch.nn.functional.pad(
|
||||
input=mask_shared_prefix,
|
||||
pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix,
|
||||
attention_mask=padded_attention_mask,
|
||||
position_ids=position_ids_shared_prefix,
|
||||
cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask_static_cache(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
# we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
max_batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
# forward run for the first part of input
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
padded_mask_1a = torch.nn.functional.pad(
|
||||
input=mask_1a,
|
||||
pad=(0, max_cache_len - mask_1a.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
_ = self.model.forward(
|
||||
input_1a,
|
||||
attention_mask=padded_mask_1a,
|
||||
position_ids=position_ids_1a,
|
||||
cache_position=torch.arange(part_a, device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
# forward run for the second part of input
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
|
||||
padded_mask_1b = torch.nn.functional.pad(
|
||||
input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
|
||||
)
|
||||
|
||||
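# cache_position marks the absolute static-cache slots written by this chunk (from part_a up to the full packed length)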
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=padded_mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
cache_position=torch.arange(
|
||||
part_a,
|
||||
input_ids_shared_prefix.shape[-1],
|
||||
device=torch_device,
|
||||
),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
914
tests/models/blt/test_tokenization_blt.py
Normal file
@ -0,0 +1,914 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import (
|
||||
SPIECE_UNDERLINE,
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
BLTTokenizer,
|
||||
BLTTokenizerFast,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
nested_simplify,
|
||||
require_jinja,
|
||||
require_read_token,
|
||||
require_sentencepiece,
|
||||
require_tiktoken,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class BLTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
from_pretrained_id = ["hf-internal-testing/blt-tokenizer", "meta-blt/BLT-2-7b-hf"]
|
||||
tokenizer_class = BLTTokenizer
|
||||
rust_tokenizer_class = BLTTokenizerFast
|
||||
|
||||
test_rust_tokenizer = False
|
||||
test_sentencepiece = True
|
||||
from_pretrained_kwargs = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def get_tokenizers(self, **kwargs):
|
||||
kwargs.update({"pad_token": "<PAD>"})
|
||||
return super().get_tokenizers(**kwargs)
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = BLTTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
tokens = tokenizer.tokenize("This is a test")
|
||||
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens),
|
||||
[285, 46, 10, 170, 382],
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(
|
||||
tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"9",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"é",
|
||||
".",
|
||||
],
|
||||
)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids,
|
||||
[8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
|
||||
)
|
||||
|
||||
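# '9' and 'é' are not in the sample vocab, so they map to id 0 and decode back as '<unk>'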
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(
|
||||
back_tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"<unk>",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"<unk>",
|
||||
".",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Let's wait for the fast tokenizer!")
|
||||
def test_save_pretrained(self):
|
||||
self.tokenizers_list += (self.rust_tokenizer_class, "hf-internal-testing/blt-tokenizer", {})
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)
|
||||
tokenizer_p = self.get_tokenizer(pretrained_name, **kwargs)
|
||||
|
||||
tmpdirname2 = tempfile.mkdtemp()
|
||||
|
||||
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
|
||||
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
|
||||
|
||||
# Check that it saves the same files, plus the tokenizer.json file for the fast one
|
||||
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
|
||||
tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
|
||||
self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
|
||||
|
||||
# Checks everything loads correctly in the same way
|
||||
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
|
||||
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
|
||||
|
||||
# Check special tokens are set accordingly on Rust and Python
|
||||
for key in tokenizer_pp.special_tokens_map:
|
||||
self.assertTrue(hasattr(tokenizer_rp, key))
|
||||
|
||||
shutil.rmtree(tmpdirname2)
|
||||
|
||||
# Save tokenizer rust, legacy_format=True
|
||||
tmpdirname2 = tempfile.mkdtemp()
|
||||
|
||||
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True)
|
||||
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
|
||||
|
||||
# Check that it saves the same files
|
||||
self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
|
||||
|
||||
# Checks everything loads correctly in the same way
|
||||
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
|
||||
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
|
||||
|
||||
# Check special tokens are set accordingly on Rust and Python
|
||||
for key in tokenizer_pp.special_tokens_map:
|
||||
self.assertTrue(hasattr(tokenizer_rp, key))
|
||||
|
||||
shutil.rmtree(tmpdirname2)
|
||||
|
||||
# Save tokenizer rust, legacy_format=False
|
||||
tmpdirname2 = tempfile.mkdtemp()
|
||||
|
||||
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False)
|
||||
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
|
||||
|
||||
# Checks it saved the tokenizer.json file
|
||||
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
|
||||
|
||||
# Checks everything loads correctly in the same way
|
||||
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
|
||||
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
|
||||
|
||||
# Check special tokens are set accordingly on Rust and Python
|
||||
for key in tokenizer_pp.special_tokens_map:
|
||||
self.assertTrue(hasattr(tokenizer_rp, key))
|
||||
|
||||
shutil.rmtree(tmpdirname2)
|
||||
|
||||
@require_torch
|
||||
def test_batch_tokenization(self):
|
||||
if not self.test_seq2seq:
|
||||
self.skipTest(reason="test_seq2seq is set to False")
|
||||
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
# Longer text that will definitely require truncation.
|
||||
text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
|
||||
" Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
|
||||
" will only worsen the violence and misery for millions of people.",
|
||||
]
|
||||
try:
|
||||
batch = tokenizer(
|
||||
text=text,
|
||||
max_length=3,
|
||||
max_target_length=10,
|
||||
return_tensors="pt",
|
||||
)
|
||||
except NotImplementedError:
|
||||
self.skipTest(reason="Encountered NotImplementedError when calling tokenizer")
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer(text, max_length=3, return_tensors="pt")
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
|
||||
batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt")
|
||||
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
|
||||
self.assertNotIn("decoder_input_ids", batch_encoder_only)
|
||||
|
||||
@unittest.skip(reason="Unfortunately way too slow to build a BPE with SentencePiece.")
|
||||
def test_save_slow_from_fast_and_reload_fast(self):
|
||||
pass
|
||||
|
||||
def test_special_tokens_initialization(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
added_tokens = [AddedToken("<special>", lstrip=True)]
|
||||
|
||||
tokenizer_r = self.get_rust_tokenizer(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs
|
||||
)
|
||||
r_output = tokenizer_r.encode("Hey this is a <special> token")
|
||||
|
||||
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
|
||||
|
||||
self.assertTrue(special_token_id in r_output)
|
||||
|
||||
if self.test_slow_tokenizer:
|
||||
tokenizer_cr = self.get_rust_tokenizer(
|
||||
pretrained_name,
|
||||
additional_special_tokens=added_tokens,
|
||||
**kwargs, # , from_slow=True <- unfortunately too slow to convert
|
||||
)
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs
|
||||
)
|
||||
|
||||
p_output = tokenizer_p.encode("Hey this is a <special> token")
|
||||
|
||||
cr_output = tokenizer_cr.encode("Hey this is a <special> token")
|
||||
|
||||
self.assertEqual(p_output, r_output)
|
||||
self.assertEqual(cr_output, r_output)
|
||||
self.assertTrue(special_token_id in p_output)
|
||||
self.assertTrue(special_token_id in cr_output)
|
||||
|
||||
@slow
|
||||
def test_tokenizer_integration(self):
|
||||
expected_encoding = {'input_ids': [[1, 4103, 689, 414, 313, 24784, 368, 2998, 408, 282, 3637, 25350, 29899, 9067, 414, 322, 282, 3637, 25350, 29899, 1457, 3018, 1312, 29899, 2151, 29897, 8128, 2498, 29899, 15503, 4220, 6956, 1973, 313, 13635, 29911, 29892, 402, 7982, 29899, 29906, 29892, 1528, 13635, 29911, 29874, 29892, 1060, 26369, 29892, 6652, 309, 29933, 814, 29892, 1060, 29931, 6779, 11410, 363, 18385, 17088, 7634, 11235, 313, 25103, 29965, 29897, 322, 18385, 17088, 28203, 313, 25103, 29954, 29897, 411, 975, 29871, 29941, 29906, 29974, 758, 3018, 1312, 4733, 297, 29871, 29896, 29900, 29900, 29974, 10276, 322, 6483, 1006, 3372, 3097, 1546, 435, 1165, 29892, 10772, 29911, 25350, 322, 323, 6073, 17907, 29889], [1, 350, 20161, 338, 8688, 304, 758, 29899, 14968, 6483, 21000, 8684, 284, 22540, 515, 443, 29880, 24025, 1426, 491, 14002, 368, 4195, 292, 373, 1716, 2175, 322, 1492, 3030, 297, 599, 15359, 29889], [1, 450, 4996, 17354, 1701, 29916, 432, 17204, 975, 278, 17366, 11203, 29889]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
|
||||
|
||||
self.tokenizer_integration_test_util(
|
||||
expected_encoding=expected_encoding,
|
||||
model_name="hf-internal-testing/blt-tokenizer",
|
||||
revision="0984d03108b1a041ed679bd253b6519b7e1a4778",
|
||||
padding=False,
|
||||
)
|
||||
|
||||
def test_picklable(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
shutil.copyfile(SAMPLE_VOCAB, f.name)
|
||||
tokenizer = BLTTokenizer(f.name, keep_accents=True)
|
||||
pickled_tokenizer = pickle.dumps(tokenizer)
|
||||
pickle.loads(pickled_tokenizer)
|
||||
|
||||
@unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.")
|
||||
def test_pickle_subword_regularization_tokenizer(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.")
|
||||
def test_subword_regularization_tokenizer(self):
|
||||
pass
|
||||
|
||||
def test_add_prefix_space(self):
|
||||
pretrained_name = "hf-internal-testing/blt-tokenizer-non-normalized"
|
||||
inputs = "Hey how are you doing"
|
||||
EXPECTED_WITH_SPACE = [1, 18637, 920, 526, 366, 2599]
|
||||
EXPECTED_WO_SPACE = [1, 29950, 1032, 920, 526, 366, 2599]
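# without a prefix space, the first word gets no leading '▁': 'Hey' is tokenized as 'H' (29950) + 'ey' (1032)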
|
||||
|
||||
slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=False, legacy=False)
|
||||
fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=False, legacy=False)
|
||||
self.assertEqual(slow_.encode(inputs), EXPECTED_WO_SPACE)
|
||||
self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
|
||||
self.assertEqual(slow_.tokenize(inputs), ["H", "ey", "▁how", "▁are", "▁you", "▁doing"])
|
||||
self.assertEqual(slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), inputs)
|
||||
self.assertEqual(
|
||||
slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True),
|
||||
fast_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True),
|
||||
)
|
||||
|
||||
slow_ = self.get_tokenizer(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
fast_ = self.get_rust_tokenizer(pretrained_name, add_prefix_space=True, legacy=False)
|
||||
self.assertEqual(slow_.encode(inputs), EXPECTED_WITH_SPACE)
|
||||
self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
|
||||
self.assertEqual(slow_.tokenize(inputs), ["▁Hey", "▁how", "▁are", "▁you", "▁doing"])
|
||||
self.assertEqual(slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), inputs)
|
||||
self.assertEqual(
|
||||
slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True),
|
||||
fast_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True),
|
||||
)
|
||||
|
||||
def test_load_tokenizer_with_model_file_only(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
hf_hub_download(repo_id="huggyblt/blt-7b", filename="tokenizer.model", local_dir=tmp_dir)
|
||||
tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir)
|
||||
self.assertEqual(tokenizer_fast.encode("This is a test"), [1, 910, 338, 263, 1243])
|
||||
|
||||
tokenizer_slow = self.tokenizer_class.from_pretrained(tmp_dir)
|
||||
self.assertEqual(tokenizer_slow.encode("This is a test"), [1, 910, 338, 263, 1243])
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class BLTIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "hf-internal-testing/blt-tokenizer-non-normalized"
|
||||
cls.tokenizer: BLTTokenizer = BLTTokenizer.from_pretrained(checkpoint_name)
|
||||
cls.rust_tokenizer = BLTTokenizerFast.from_pretrained(checkpoint_name)
|
||||
return cls
|
||||
|
||||
@require_torch
|
||||
def integration_tests(self):
|
||||
inputs = self.tokenizer(
|
||||
["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889],
|
||||
[1, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718],
|
||||
],
|
||||
"attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
},
|
||||
)
|
||||
|
||||
def test_fast_special_tokens(self):
|
||||
slow_tokenizer = self.tokenizer
|
||||
fast_tokenizer = self.rust_tokenizer
|
||||
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [1, 319, 4559, 1243]
|
||||
|
||||
fast_tokenizer.add_eos_token = False
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [1, 319, 4559, 1243]
|
||||
|
||||
fast_tokenizer.add_eos_token = True
|
||||
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [1, 319, 4559, 1243, 2]
|
||||
|
||||
slow_tokenizer.add_eos_token = True
|
||||
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [1, 319, 4559, 1243, 2]
|
||||
|
||||
fast_tokenizer = BLTTokenizerFast.from_pretrained(
|
||||
"hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False
|
||||
)
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [319, 4559, 1243, 2]
|
||||
|
||||
slow_tokenizer = BLTTokenizer.from_pretrained(
|
||||
"hf-internal-testing/blt-tokenizer", add_eos_token=True, add_bos_token=False
|
||||
)
|
||||
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [319, 4559, 1243, 2]
|
||||
|
||||
self.tokenizer.add_eos_token = False
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
@slow
|
||||
def test_conversion(self):
|
||||
# This is excruciatingly slow since it has to recreate the entire merge
|
||||
# list from the original vocabulary in spm
|
||||
self.rust_tokenizer.save_pretrained("./out")
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
self.rust_tokenizer.save_pretrained(dirname)
|
||||
|
||||
with open(os.path.join(dirname, "tokenizer.json")) as f:
|
||||
old_serialized = f.read()
|
||||
|
||||
new_tokenizer = convert_slow_tokenizer(self.tokenizer)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
new_tokenizer.save(f.name)
|
||||
# Re-opening since `f` is in bytes.
|
||||
new_serialized = open(f.name).read()
|
||||
with open("out_tokenizer.json", "w") as g:
|
||||
g.write(new_serialized)
|
||||
|
||||
self.assertEqual(old_serialized, new_serialized)
|
||||
|
||||
def test_simple_encode_decode(self):
|
||||
pyth_tokenizer = self.tokenizer
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
|
||||
self.assertEqual(rust_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
|
||||
self.assertEqual(pyth_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")
|
||||
self.assertEqual(rust_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")
|
||||
|
||||
# bytefallback showcase
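# the character 谛 is not in the vocab, so it is encoded with byte-level fallback tokens (ids 235, 179, 158)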
|
||||
self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip
|
||||
self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392]) # fmt: skip
|
||||
self.assertEqual(
|
||||
pyth_tokenizer.decode(
|
||||
[1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
|
||||
),
|
||||
"生活的真谛是",
|
||||
)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(
|
||||
[1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
|
||||
),
|
||||
"生活的真谛是",
|
||||
)
|
||||
|
||||
# Inner spaces showcase
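# inner spaces are preserved: different amounts of whitespace between the words map to
# different tokens (29871 vs 259) and round-trip through decode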
|
||||
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
|
||||
self.assertEqual(pyth_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")
|
||||
self.assertEqual(rust_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043])
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043])
|
||||
self.assertEqual(pyth_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello")
|
||||
self.assertEqual(rust_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
||||
self.assertEqual(rust_tokenizer.encode(""), [1])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
|
||||
self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678])
|
||||
self.assertEqual(rust_tokenizer.encode(" "), [1, 1678])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||
|
||||
def test_no_differences_showcase(self):
|
||||
pyth_tokenizer = self.tokenizer
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
||||
self.assertEqual(rust_tokenizer.encode(""), [1])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
|
||||
self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678])
|
||||
self.assertEqual(rust_tokenizer.encode(" "), [1, 1678])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
|
||||
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
|
||||
|
||||
def test_no_differences_decode(self):
|
||||
pyth_tokenizer = self.tokenizer
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(pyth_tokenizer.decode([869]), ".")
|
||||
self.assertEqual(rust_tokenizer.decode([869]), ".")
|
||||
|
||||
self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
|
||||
self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
|
||||
|
||||
def test_no_differences_special_tokens(self):
|
||||
pyth_tokenizer = self.tokenizer
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
self.assertEqual(pyth_tokenizer.encode(""), [1])
|
||||
self.assertEqual(rust_tokenizer.encode(""), [1])
|
||||
|
||||
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
|
||||
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
|
||||
|
||||
@unittest.skipIf(
|
||||
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
|
||||
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
|
||||
)
|
||||
def test_integration_test_xnli(self):
|
||||
import tqdm
|
||||
|
||||
pyth_tokenizer = self.tokenizer
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go")
|
||||
for item in tqdm.tqdm(dataset["validation"]):
|
||||
string = item["code"]
|
||||
encoded1 = pyth_tokenizer.encode(string)
|
||||
encoded2 = rust_tokenizer.encode(string)
|
||||
|
||||
self.assertEqual(encoded1, encoded2)
|
||||
|
||||
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
|
||||
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(decoded1, decoded2)
|
||||
|
||||
dataset = load_dataset("facebook/xnli", "all_languages")
|
||||
|
||||
for item in tqdm.tqdm(dataset["train"]):
|
||||
for string in item["premise"].values():
|
||||
encoded1 = pyth_tokenizer.encode(string)
|
||||
encoded2 = rust_tokenizer.encode(string)
|
||||
|
||||
self.assertEqual(encoded1, encoded2)
|
||||
|
||||
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
|
||||
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(decoded1, decoded2)
|
||||
|
||||
def test_special_token_special_word(self):
|
||||
# the word inform should be split as ['in', 'form']
|
||||
tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True)
|
||||
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
|
||||
|
||||
example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
|
||||
# Make sure dummy space is added if it is indeed the first word
|
||||
example_inputs = tokenizer.tokenize("inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["▁inform", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
out1 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
self.assertEqual(out1, "<REPR_END>inform")
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
|
||||
)
|
||||
# decoding strips the added prefix space.
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [32000, 262, 689])  # no spiece underline ('▁', id 29871) is prepended before the added special token
|
||||
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
# TODO @ArthurZ currently we strip left and right, so this will not keep the spaces
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
|
||||
### Let's make sure decoding does not add extra spaces here and there
|
||||
# TODO @ArthurZ this should be affected by the lstrip/rstrip/single word /normalize refactoring
|
||||
# Since currently we always strip left and right of the token, results are as such
|
||||
input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [1, 15043, 1, 3525])
|
||||
tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s> Hello<s>how")
|
||||
|
||||
# Let's make sure that if there are any spaces, we don't remove them!
|
||||
input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
|
||||
tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s> Hello<s> how")
|
||||
|
||||
# Let's make sure the space is preserved
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=True)
|
||||
self.assertEqual(input_ids, [1, 22172])
|
||||
tokens = tokenizer.tokenize("hello")
|
||||
self.assertEqual(tokens, ["▁hello"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s> hello")
|
||||
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [22172])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "hello")
|
||||
|
||||
def test_no_prefix_space(self):
|
||||
tokenizer_no_prefix_space = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", add_prefix_space=False)
|
||||
no_prefix_space_tokens = tokenizer_no_prefix_space.tokenize("Hey")
|
||||
self.assertEqual(no_prefix_space_tokens, ["H", "ey"])
|
||||
|
||||
tokenizer = BLTTokenizerFast.from_pretrained(
|
||||
"huggyblt/blt-7b", legacy=False, from_slow=True, add_prefix_space=False
|
||||
)
|
||||
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
|
||||
|
||||
example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["<REPR_END>", "in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
|
||||
# Make sure dummy space is added if it is indeed the first word
|
||||
example_inputs = tokenizer.tokenize("inform<s>. Hey. .")
|
||||
self.assertEqual(example_inputs, ["in", "form", "<s>", ".", "▁Hey", ".", "▁▁▁▁▁▁", "▁."])
|
||||
out1 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
self.assertEqual(out1, "<REPR_END>inform")
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
|
||||
)
|
||||
# decoding strips the added prefix space.
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [32000, 262, 689])  # no spiece underline ('▁', id 29871) is added since add_prefix_space=False
|
||||
|
||||
out2 = tokenizer.decode(
|
||||
tokenizer.encode(" <REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
|
||||
)
|
||||
self.assertEqual(out2, "<REPR_END>inform")
|
||||
|
||||
input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [1, 15043, 1, 3525])
|
||||
tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s> Hello<s>how")
|
||||
|
||||
# Let's make sure that if there are any spaces, we don't remove them!
|
||||
input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [29871, 1, 15043, 1, 920])
|
||||
tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
|
||||
self.assertEqual(tokens, ["▁", "<s>", "▁Hello", "<s>", "▁how"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||
|
||||
# Let's make sure the space is preserved
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=True)
|
||||
self.assertEqual(input_ids, [1, 12199])
|
||||
tokens = tokenizer.tokenize("hello")
|
||||
self.assertEqual(tokens, ["hello"])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "<s>hello")
|
||||
|
||||
input_ids = tokenizer.encode("hello", add_special_tokens=False)
|
||||
self.assertEqual(input_ids, [12199])
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, "hello")
|
||||
|
||||
def test_some_edge_cases(self):
|
||||
tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False)
|
||||
|
||||
sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
|
||||
self.assertEqual(sp_tokens, ["<", "s", ">>"])
|
||||
tokens = tokenizer.tokenize("<s>>")
|
||||
self.assertNotEqual(sp_tokens, tokens)
|
||||
self.assertEqual(tokens, ["<s>", ">"])
|
||||
|
||||
tokens = tokenizer.tokenize("")
|
||||
self.assertEqual(tokens, [])
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
|
||||
|
||||
tokens = tokenizer.tokenize(" ")
|
||||
self.assertEqual(tokens, ["▁▁"])
|
||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str))
|
||||
|
||||
tokens = tokenizer.tokenize("▁")
|
||||
self.assertEqual(tokens, ["▁▁"])
|
||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁", out_type=str))
|
||||
|
||||
tokens = tokenizer.tokenize(" ▁")
|
||||
self.assertEqual(tokens, ["▁▁▁"])
|
||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
|
||||
|
||||
def test_fast_post_processor(self):
|
||||
tokenizer = BLTTokenizerFast(
|
||||
SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False
|
||||
)
|
||||
tokenizer.encode(" Hey ")
|
||||
|
||||
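# requesting add_bos/add_eos without defining the corresponding token must raise a ValueError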
with self.assertRaises(ValueError):
|
||||
tokenizer = BLTTokenizerFast(
|
||||
SAMPLE_VOCAB, bos_token=None, eos_token="<s>", add_bos_token=True, add_eos_token=False
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer = BLTTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True)
|
||||
|
||||
@require_jinja
|
||||
def test_tokenization_for_chat(self):
|
||||
tokenizer = BLTTokenizer.from_pretrained("huggyblt/blt-7b", legacy=False)
|
||||
|
||||
test_chats = [
|
||||
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful chatbot."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Nice to meet you."},
|
||||
],
|
||||
[{"role": "user", "content": "Hello!"}],
|
||||
]
|
||||
# Matt: The third test case tests the default system message, but if this is ever changed in the
|
||||
# class/repo code then that test will fail, and the case will need to be updated.
|
||||
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
|
||||
# fmt: off
|
||||
expected_tokens = [
|
||||
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
|
||||
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
|
||||
[1, 29961, 25580, 29962, 15043, 29991, 518, 29914, 25580, 29962]
|
||||
]
|
||||
# fmt: on
|
||||
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
|
||||
self.assertListEqual(tokenized_chat, expected_tokens)
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class CommonSpmIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
A class that regroups important tests to make sure that we properly handle special tokens.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
tokenizer = BLTTokenizer(SAMPLE_VOCAB, extra_ids=0, add_bos_token=False, legacy=False)
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("<s>", rstrip=False, lstrip=False)]})
|
||||
cls.tokenizer = tokenizer
|
||||
return cls
|
||||
|
||||
def test_add_dummy_prefix(self):
|
||||
# make sure `'▁'` is prepended, and outputs match sp_model's
|
||||
# `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
|
||||
input_ids = self.tokenizer.encode(". Hello")
|
||||
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
|
||||
sp_encode = self.tokenizer.sp_model.encode(". Hello")
|
||||
self.assertEqual(input_ids, [7] + sp_encode)
|
||||
tokens = self.tokenizer.tokenize(". Hello")
|
||||
self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])
|
||||
|
||||
tokens = self.tokenizer.tokenize("")
|
||||
self.assertEqual(tokens, [])
|
||||
self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str))
|
||||
|
||||
tokens = self.tokenizer.tokenize(" ")
|
||||
self.assertEqual(tokens, [])
|
||||
self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str))
|
||||
|
||||
tokens = self.tokenizer.tokenize("▁")
|
||||
self.assertEqual(tokens, [])
|
||||
self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str))
|
||||
|
||||
def test_remove_extra_whitespaces(self):
|
||||
# make sure the extra spaces are eaten even though the sample vocab does not have
|
||||
# `______` and the sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute is set to False
|
||||
|
||||
input_ids = self.tokenizer.encode(" . Hello")
|
||||
self.assertEqual(input_ids, [7, 4, 156, 86, 20])
|
||||
sp_encode = self.tokenizer.sp_model.encode(" . Hello")
|
||||
self.assertEqual(input_ids, [7] + sp_encode)
|
||||
tokens = self.tokenizer.tokenize(" . Hello")
|
||||
self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])
|
||||
|
||||
# `'▁'` is also a whitespace
|
||||
input_ids = self.tokenizer.encode("▁He is not")
|
||||
self.assertEqual(input_ids, [156, 46, 44])
|
||||
tokens = self.tokenizer.tokenize("▁He is not")
|
||||
sp_encode = [
|
||||
self.tokenizer.sp_model.piece_to_id("▁He"),
|
||||
self.tokenizer.sp_model.piece_to_id("▁is"),
|
||||
self.tokenizer.sp_model.piece_to_id("▁not"),
|
||||
]
|
||||
self.assertEqual(input_ids, sp_encode)
|
||||
self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added
|
||||
|
||||
input_ids = self.tokenizer.encode("▁He is not<s> ▁He")
|
||||
self.assertEqual(input_ids, [156, 46, 44, 1, 156])
|
||||
tokens = self.tokenizer.tokenize("▁He is not<s> ▁He")
|
||||
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "<s>", "▁He"]) # spaces are eaten by spm + our strip
|
||||
# make sure that the output after the extra id is the same as if
|
||||
# extra_id was not there
|
||||
input_ids = self.tokenizer.encode("▁He is not ▁He")
|
||||
self.assertEqual(input_ids, [156, 46, 44, 156])
|
||||
tokens = self.tokenizer.tokenize("▁He is not ▁He")
|
||||
self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start
|
||||
|
||||
def test_character_after_special_token(self):
|
||||
# Make sure that `tokenizer.tokenize` is similar to
|
||||
# adding the equivalent special token to the vocab
|
||||
input_ids = self.tokenizer.encode("Hey <s>I")
|
||||
self.assertEqual(input_ids, [156, 30, 1, 100])
|
||||
sp_encode = self.tokenizer.sp_model.encode("Hey .I")
|
||||
# the last token should be 100
|
||||
self.assertEqual(input_ids[-1], sp_encode[-1])
|
||||
tokens = self.tokenizer.tokenize("<s>I")
|
||||
self.assertEqual(tokens, ["<s>", "I"])
|
||||
|
||||
input_ids = self.tokenizer.encode("Hello, <s>,")
|
||||
self.assertEqual(input_ids, [156, 86, 20, 3, 1, 3])
|
||||
tokens = self.tokenizer.tokenize("Hello, <s>,")
|
||||
self.assertEqual(tokens, ["▁He", "ll", "o", ",", "<s>", ","])
|
||||
|
||||
def test_special_tokens_strip(self):
|
||||
input_ids = self.tokenizer.encode(" <s> ,")
|
||||
self.assertEqual(input_ids, [1, 7, 3])
|
||||
tokens = self.tokenizer.tokenize(" <s> ,")
|
||||
# spaces are eaten by rstrip / lstrip + spm sp_model.encode(" ") = []
|
||||
self.assertEqual(tokens, ["<s>", "▁", ","])
|
||||
|
||||
input_ids = self.tokenizer.encode("No <s> ▁He")
|
||||
self.assertEqual(input_ids, [284, 1, 156])
|
||||
tokens = self.tokenizer.tokenize("No <s> ▁He")
|
||||
self.assertEqual(tokens, ["▁No", "<s>", "▁He"]) # spaces are eaten by rstrip / lstrip
|
||||
|
||||
|
||||
@require_tiktoken
|
||||
@require_read_token
|
||||
class TikTokenIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
A class that regroups important tests to make sure that we properly handle special tokens.
|
||||
"""
|
||||
|
||||
def test_tiktoken_blt(self):
|
||||
model_path = "hf-internal-testing/blt-3-8b-internal"
|
||||
subfolder = "original"
|
||||
test_text = "This is a test sentence."
|
||||
test_tokens = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
|
||||
num_reserved_special_tokens = 256
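# the checkpoint reserves a block of special-token slots; every slot not named explicitly
# below is filled with a <|reserved_special_token_i|> placeholder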
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|reserved_special_token_2|>",
|
||||
"<|reserved_special_token_3|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|reserved_special_token_4|>",
|
||||
"<|eot_id|>",  # end of turn
|
||||
"<|python_tag|>",
|
||||
] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
|
||||
|
||||
tiktoken_tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
||||
model_path,
|
||||
subfolder=subfolder,
|
||||
additional_special_tokens=special_tokens,
|
||||
bos_token="<|begin_of_text|>",
|
||||
eos_token="<|end_of_text|>",
|
||||
)
|
||||
tokens = tiktoken_tokenizer.tokenize("<|begin_of_text|> " + test_text)
|
||||
self.assertEqual(tokens[0], "<|begin_of_text|>")
|
||||
|
||||
tiktoken_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
subfolder=subfolder,
|
||||
legacy=False,
|
||||
additional_special_tokens=special_tokens,
|
||||
bos_token="<|begin_of_text|>",
|
||||
eos_token="<|end_of_text|>",
|
||||
add_bos_token=True,
|
||||
add_eos_token=True,
|
||||
)
|
||||
self.assertTrue(isinstance(tiktoken_tokenizer, PreTrainedTokenizerFast))
|
||||
|
||||
tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True)
|
||||
self.assertEqual(tokens, test_tokens)
|
||||
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
tiktoken_tokenizer.save_pretrained(tmpdirname)
|
||||
tokenizer_reload = AutoTokenizer.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertTrue(isinstance(tokenizer_reload, PreTrainedTokenizerFast))
|
||||
tokens = tokenizer_reload.encode(test_text, add_special_tokens=True)
|
||||
self.assertEqual(tokens, test_tokens)
|
||||
shutil.rmtree(tmpdirname)
|
||||
|
||||
tiktoken_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
subfolder=subfolder,
|
||||
additional_special_tokens=special_tokens,
|
||||
bos_token="<|begin_of_text|>",
|
||||
eos_token="<|end_of_text|>",
|
||||
from_slow=True,
|
||||
add_bos_token=True,
|
||||
add_eos_token=True,
|
||||
)
|
||||
tokens = tiktoken_tokenizer.encode(test_text, add_special_tokens=True)
|
||||
self.assertEqual(tokens, test_tokens)
|