change MiniCPM_o_2_6TokenizerFast to Qwen2TokenizerFast & fit the new WhisperAttention API

GQN 2025-07-02 17:47:04 +08:00
parent 49391cde01
commit afae5f7767
4 changed files with 26 additions and 131 deletions


@@ -347,7 +347,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
         ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
         ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("mgp-str", ("MgpstrTokenizer", None)),
-        ("minicpm_o_2_6", ("Qwen2Tokenizer", "MiniCPM_o_2_6TokenizerFast" if is_tokenizers_available() else None)),
+        ("minicpm_o_2_6", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
         (
             "minimax",
             (
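With the mapping entry pointing at the stock class, `AutoTokenizer` no longer needs the custom subclass this commit deletes below. A minimal sketch of the resolution path (the checkpoint id is an assumption, for illustration only):

```python
from transformers import AutoTokenizer

# "openbmb/MiniCPM-o-2_6" is an assumed repo id; with in-tree support, any config
# whose model_type is "minicpm_o_2_6" now resolves its fast tokenizer through
# TOKENIZER_MAPPING_NAMES to the stock Qwen2TokenizerFast.
tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM-o-2_6")
assert type(tok).__name__ == "Qwen2TokenizerFast"
```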


@@ -49,7 +49,7 @@ from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper
 from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipPreTrainedModel
 from transformers.models.idefics2.modeling_idefics2 import Idefics2Encoder
-from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder
+from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperConfig, WhisperEncoder
 from transformers.activations import ACT2FN
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -659,7 +659,7 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
         result_text = []
         for result in result_ids:
             result = result[result != 0]
-            if result[0] == tokenizer.bos_id:
+            if result[0] == tokenizer.bos_token_id:
                 result = result[1:]
             if result[-1] in terminators:
                 result = result[:-1]
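`bos_id` was a property of the deleted subclass; `bos_token_id` is the standard attribute on every Hugging Face tokenizer, so this cleanup now works against a stock `Qwen2TokenizerFast`. A minimal sketch of the pattern (the helper name and `terminators` argument are illustrative, not part of the model code):

```python
import torch

def strip_special_ids(result: torch.Tensor, tokenizer, terminators: set[int]) -> torch.Tensor:
    result = result[result != 0]              # drop padding ids
    if result[0] == tokenizer.bos_token_id:   # standard HF attribute, not the old bos_id
        result = result[1:]
    if int(result[-1]) in terminators:        # e.g. eos / im_end ids
        result = result[:-1]
    return result
```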
@@ -1020,9 +1020,9 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
             spk_embeds = self._get_last_spk_embeds(inputs, outputs)

         if isinstance(answer, list):
-            answer = [i.replace(tokenizer.tts_end, "") for i in answer]
+            answer = [i.replace("<|tts_eos|>", "") for i in answer]
         else:
-            answer = answer.replace(tokenizer.tts_end, "")
+            answer = answer.replace("<|tts_eos|>", "")

         if return_dict:
             return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr)
@@ -1074,7 +1074,7 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
             else:
                 logger.error("Invalid content type:", c)

-        cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input)
+        cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
         if not self.is_first and self.new_user_msg and msg["role"] == "user":  # new user add im_start
             if self.llm_generated:
                 if self.llm_generate_completed:
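This hunk is a genuine bug fix, not just a rename: `omni_input` is the condition of the ternary, so the fallback branch must join `cur_msgs`; joining the flag itself would raise. A short illustration:

```python
cur_msgs = ["text before image", "text after image"]
omni_input = False
# old code: "\n".join(omni_input) -> TypeError: can only join an iterable of str
cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
```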
@@ -1191,8 +1191,8 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
         generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
         input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()

-        spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
-        spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
+        spk_start_idx = torch.where(input_ids[0] == tokenizer.convert_tokens_to_ids("<|spk_bos|>"))[0]
+        spk_end_idx = torch.where(input_ids[0] == tokenizer.convert_tokens_to_ids("<|spk_eos|>"))[0]
         spk_bounds = [
             torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
         ]  # List[Tensor], (1,2)
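The subclass properties `spk_start_id`/`spk_end_id` are replaced by `convert_tokens_to_ids`, which every tokenizer exposes. The same locate-a-span pattern recurs in the processor hunks below; a self-contained sketch (the helper is illustrative):

```python
import torch

def token_span(input_ids: torch.Tensor, tokenizer, start_tok: str, end_tok: str) -> torch.Tensor:
    # Find (start, end) index pairs of the content strictly between two
    # special-token markers, mirroring the spk_bounds computation above.
    start_idx = torch.where(input_ids == tokenizer.convert_tokens_to_ids(start_tok))[0]
    end_idx = torch.where(input_ids == tokenizer.convert_tokens_to_ids(end_tok))[0]
    return torch.hstack([(start_idx + 1).unsqueeze(-1), end_idx.unsqueeze(-1)])
```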
@@ -1872,7 +1872,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
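Newer transformers releases dropped the `WHISPER_ATTENTION_CLASSES` dict: `WhisperAttention` is now instantiated directly, and the eager/SDPA/flash-attention backend is dispatched inside the module from `config._attn_implementation` (the exact version threshold is an assumption, roughly the 4.48 attention refactor). A standalone sketch of the new construction:

```python
from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperAttention

config = WhisperConfig()
self_attn = WhisperAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
    config=config,  # the config carries _attn_implementation for backend dispatch
)
```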


@@ -147,6 +147,10 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
         super().__init__(image_processor, feature_extractor, tokenizer)
         self.version = image_processor.version
+        self.audio_start_token = "<|audio_start|>"
+        self.audio_end_token = "<|audio_end|>"
+        self.spk_bos_token = "<|spk_bos|>"
+        self.spk_eos_token = "<|spk_eos|>"

     def __call__(
         self,
@@ -209,11 +213,11 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
             total_unk_len = 0
             for _ in range(num_audio_chunks):
                 unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
-                place_holders += self.tokenizer.audio_start + "<unk>" * unk_len + self.tokenizer.audio_end
+                place_holders += self.audio_start_token + self.image_processor.unk_token * unk_len + self.audio_end_token
                 total_unk_len += unk_len
             audio_placeholder = place_holders
         else:
-            audio_placeholder = self.tokenizer.audio_start + "<unk>" * output_lens + self.tokenizer.audio_end
+            audio_placeholder = self.audio_start_token + self.image_processor.unk_token * output_lens + self.audio_end_token

         return audio_placeholder
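The placeholder markers now live on the processor and the filler token comes from the image processor, so nothing here depends on the deleted tokenizer subclass. A standalone restatement of the construction, with illustrative names:

```python
def build_audio_placeholder(unk_token: str, output_lens: int, chunk_len: int | None = None) -> str:
    # Wrap `output_lens` filler tokens in audio markers, optionally split into
    # fixed-size chunks, mirroring the chunked branch of the hunk above.
    start, end = "<|audio_start|>", "<|audio_end|>"
    if chunk_len is None:
        return start + unk_token * output_lens + end
    placeholder, total = "", 0
    while total < output_lens:
        n = min(chunk_len, output_lens - total)
        placeholder += start + unk_token * n + end
        total += n
    return placeholder

print(build_audio_placeholder("<unk>", 5, chunk_len=2))
# <|audio_start|><unk><unk><|audio_end|> repeated for chunks of 2, 2, 1 tokens
```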
@@ -321,9 +325,9 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         result_text = []
         for result in output_ids:
             result = result[result != 0]
-            if result[0] == self.tokenizer.bos_id:
+            if result[0] == self.tokenizer.bos_token_id:
                 result = result[1:]
-            if result[-1] == self.tokenizer.eos_id:
+            if result[-1] == self.tokenizer.eos_token_id:
                 result = result[:-1]
             result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
         return result_text
@@ -337,10 +341,10 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         """
         result = args[0]
         result = result[result != 0]
-        if result[0] == self.tokenizer.bos_id:
+        if result[0] == self.tokenizer.bos_token_id:
             result = result[1:]
-        if result[-1] == self.tokenizer.eos_id or (
-            hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+        if result[-1] == self.tokenizer.eos_token_id or (
+            hasattr(self.tokenizer, "eot_token_id") and result[-1] == self.tokenizer.eot_token_id
         ):
             result = result[:-1]
         return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
@@ -352,8 +356,8 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         input_ids = torch.tensor(input_ids, dtype=torch.int32)

         ## image bound
-        start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
-        end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+        start_cond = (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.im_start_token)) | (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.slice_start_token))
+        end_cond = (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.im_end_token)) | (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.slice_end_token))
         image_start_idx = torch.where(start_cond)[0]
         image_start_idx += 1
@@ -369,13 +373,13 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         )

         ## audio bound
-        audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
-        audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
+        audio_start_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.audio_start_token))[0]
+        audio_end_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.audio_end_token))[0]
         assert len(audio_start_idx) == len(audio_end_idx)
         audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])

-        spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
-        spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
+        spk_start_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.spk_bos_token))[0]
+        spk_end_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.spk_eos_token))[0]
         assert len(spk_start_idx) == len(spk_end_idx)
         spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])


@@ -1,109 +0,0 @@
-# coding=utf-8
-# Copyright 2025 The OpenBMB Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import Qwen2TokenizerFast
-
-
-class MiniCPM_o_2_6TokenizerFast(Qwen2TokenizerFast):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        # image
-        self.im_start = "<image>"
-        self.im_end = "</image>"
-        self.ref_start = "<ref>"
-        self.ref_end = "</ref>"
-        self.box_start = "<box>"
-        self.box_end = "</box>"
-        self.quad_start = "<quad>"
-        self.quad_end = "</quad>"
-        self.slice_start = "<slice>"
-        self.slice_end = "</slice>"
-        self.im_id_start = "<image_id>"
-        self.im_id_end = "</image_id>"
-        # audio
-        self.audio_start = "<|audio_start|>"
-        self.audio_end = "<|audio_end|>"
-        self.spk_start = "<|spk_bos|>"
-        self.spk_end = "<|spk_eos|>"
-        self.tts_start = "<|tts_bos|>"
-        self.tts_end = "<|tts_eos|>"
-
-    @property
-    def eos_id(self):
-        return self.eos_token_id
-
-    @property
-    def bos_id(self):
-        return self.bos_token_id
-
-    @property
-    def unk_id(self):
-        return self.unk_token_id
-
-    @property
-    def im_start_id(self):
-        return self.convert_tokens_to_ids(self.im_start)
-
-    @property
-    def im_end_id(self):
-        return self.convert_tokens_to_ids(self.im_end)
-
-    @property
-    def slice_start_id(self):
-        return self.convert_tokens_to_ids(self.slice_start)
-
-    @property
-    def slice_end_id(self):
-        return self.convert_tokens_to_ids(self.slice_end)
-
-    @property
-    def im_id_start_id(self):
-        return self.convert_tokens_to_ids(self.im_id_start)
-
-    @property
-    def im_id_end_id(self):
-        return self.convert_tokens_to_ids(self.im_id_end)
-
-    @property
-    def audio_start_id(self):
-        return self.convert_tokens_to_ids(self.audio_start)
-
-    @property
-    def audio_end_id(self):
-        return self.convert_tokens_to_ids(self.audio_end)
-
-    @property
-    def spk_start_id(self):
-        return self.convert_tokens_to_ids(self.spk_start)
-
-    @property
-    def spk_end_id(self):
-        return self.convert_tokens_to_ids(self.spk_end)
-
-    @property
-    def tts_start_id(self):
-        return self.convert_tokens_to_ids(self.tts_start)
-
-    @property
-    def tts_end_id(self):
-        return self.convert_tokens_to_ids(self.tts_end)
-
-    @staticmethod
-    def escape(text: str) -> str:
-        return text
-
-    @staticmethod
-    def unescape(text: str) -> str:
-        return text
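With the subclass deleted, its only real job, mapping marker strings to token ids, is covered by the base API, provided the checkpoint's tokenizer files already register these strings as added tokens (assumed here). A sketch of the equivalents:

```python
from transformers import Qwen2TokenizerFast

tok = Qwen2TokenizerFast.from_pretrained("openbmb/MiniCPM-o-2_6")  # assumed repo id

bos_id = tok.bos_token_id                              # replaces tokenizer.bos_id
spk_bos_id = tok.convert_tokens_to_ids("<|spk_bos|>")  # replaces tokenizer.spk_start_id
tts_eos_token = "<|tts_eos|>"                          # replaces tokenizer.tts_end
# If the markers are missing from the vocab, register them explicitly:
# tok.add_special_tokens({"additional_special_tokens": ["<|spk_bos|>", "<|spk_eos|>"]})
```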