Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Change MiniCPM_o_2_6TokenizerFast to Qwen2TokenizerFast & adapt to the new WhisperAttention
parent 49391cde01
commit afae5f7767
@@ -347,7 +347,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
         ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
         ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("mgp-str", ("MgpstrTokenizer", None)),
-        ("minicpm_o_2_6", ("Qwen2Tokenizer", "MiniCPM_o_2_6TokenizerFast" if is_tokenizers_available() else None)),
+        ("minicpm_o_2_6", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
         (
             "minimax",
             (
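Note: with this mapping change, loading the model's tokenizer through the auto classes resolves to the stock fast tokenizer. A minimal sketch of the observable effect (the checkpoint name is an assumption for illustration, not part of this diff):

    from transformers import AutoTokenizer

    # Hypothetical checkpoint name, used only to illustrate the mapping.
    tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM-o-2_6")
    print(type(tok).__name__)  # expected after this patch: Qwen2TokenizerFast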
@@ -49,7 +49,7 @@ from transformers.generation.logits_process import LogitsProcessor, TopKLogitsWarper
 from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipPreTrainedModel
 from transformers.models.idefics2.modeling_idefics2 import Idefics2Encoder
-from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder
+from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperConfig, WhisperEncoder
 from transformers.activations import ACT2FN
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.integrations import is_deepspeed_zero3_enabled
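Note: WHISPER_ATTENTION_CLASSES was a dict mapping `config._attn_implementation` (e.g. "eager", "sdpa") to a per-backend attention class; recent transformers releases removed it from modeling_whisper in favor of a single WhisperAttention, which is why the import now targets the class directly.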
@@ -659,7 +659,7 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
         result_text = []
         for result in result_ids:
             result = result[result != 0]
-            if result[0] == tokenizer.bos_id:
+            if result[0] == tokenizer.bos_token_id:
                 result = result[1:]
             if result[-1] in terminators:
                 result = result[:-1]
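Note: `bos_id` was an alias defined only on the tokenizer subclass deleted at the bottom of this commit; `bos_token_id` is the standard attribute every transformers tokenizer exposes. A minimal sketch of the trimming pattern, assuming `result` is a 1-D tensor of generated ids and `terminators` a list of terminator ids:

    import torch

    def strip_special(result: torch.Tensor, tokenizer, terminators: list) -> torch.Tensor:
        result = result[result != 0]             # drop padding (id 0)
        if result[0] == tokenizer.bos_token_id:  # standard attribute, present on Qwen2TokenizerFast
            result = result[1:]
        if result[-1] in terminators:            # strip a trailing terminator id
            result = result[:-1]
        return result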
@@ -1020,9 +1020,9 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
         spk_embeds = self._get_last_spk_embeds(inputs, outputs)

         if isinstance(answer, list):
-            answer = [i.replace(tokenizer.tts_end, "") for i in answer]
+            answer = [i.replace("<|tts_eos|>", "") for i in answer]
         else:
-            answer = answer.replace(tokenizer.tts_end, "")
+            answer = answer.replace("<|tts_eos|>", "")

         if return_dict:
             return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr)
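Note: `tts_end` likewise existed only on the removed subclass, where it was set to "<|tts_eos|>" (see the deleted file at the end of this commit), so the literal token string is inlined at the call sites.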
@@ -1074,7 +1074,7 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
             else:
                 logger.error("Invalid content type:", c)

-        cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input)
+        cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
         if not self.is_first and self.new_user_msg and msg["role"] == "user":  # new user add im_start
             if self.llm_generated:
                 if self.llm_generate_completed:
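Note: this hunk is an unrelated bug fix: when `omni_input` is falsy, the fallback previously joined `omni_input` itself instead of the accumulated `cur_msgs` list.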
@@ -1191,8 +1191,8 @@ class MiniCPM_o_2_6Model(MiniCPM_o_2_6PreTrainedModel):
         generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
         input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()

-        spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
-        spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
+        spk_start_idx = torch.where(input_ids[0] == tokenizer.convert_tokens_to_ids("<|spk_bos|>"))[0]
+        spk_end_idx = torch.where(input_ids[0] == tokenizer.convert_tokens_to_ids("<|spk_eos|>"))[0]
         spk_bounds = [
             torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
         ]  # List[Tensor], (1,2)
@@ -1872,7 +1872,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
-        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
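Note: with the dispatch dict gone, the encoder layer constructs WhisperAttention directly; in recent transformers the attention backend is selected inside the module from `config._attn_implementation` rather than by picking a class up front. A minimal construction sketch (default config values, for illustration only):

    from transformers.models.whisper.modeling_whisper import WhisperAttention, WhisperConfig

    config = WhisperConfig()  # defaults; the real model passes its own audio config
    attn = WhisperAttention(
        embed_dim=config.d_model,                  # encoder hidden size
        num_heads=config.encoder_attention_heads,  # attention heads per encoder layer
        dropout=config.attention_dropout,          # dropout on attention probabilities
    )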
@@ -147,6 +147,10 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
         super().__init__(image_processor, feature_extractor, tokenizer)
         self.version = image_processor.version
+        self.audio_start_token = "<|audio_start|>"
+        self.audio_end_token = "<|audio_end|>"
+        self.spk_bos_token = "<|spk_bos|>"
+        self.spk_eos_token = "<|spk_eos|>"

     def __call__(
         self,
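Note: the processor now carries the special-token strings itself (values copied from the removed tokenizer subclass), so the placeholder and bound-building code below can resolve ids with plain `convert_tokens_to_ids` calls instead of custom tokenizer attributes.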
@@ -209,11 +213,11 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
             total_unk_len = 0
             for _ in range(num_audio_chunks):
                 unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
-                place_holders += self.tokenizer.audio_start + "<unk>" * unk_len + self.tokenizer.audio_end
+                place_holders += self.audio_start_token + self.image_processor.unk_token * unk_len + self.audio_end_token
                 total_unk_len += unk_len
             audio_placeholder = place_holders
         else:
-            audio_placeholder = self.tokenizer.audio_start + "<unk>" * output_lens + self.tokenizer.audio_end
+            audio_placeholder = self.audio_start_token + self.image_processor.unk_token * output_lens + self.audio_end_token

         return audio_placeholder

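Note: the placeholder is a run of unk tokens bracketed by the audio delimiters, with the unk string now taken from the image processor rather than hard-coded. Illustrative sketch, assuming `unk_token` is "<unk>" and three audio embedding slots:

    unk_token = "<unk>"  # assumed value of self.image_processor.unk_token
    placeholder = "<|audio_start|>" + unk_token * 3 + "<|audio_end|>"
    # -> "<|audio_start|><unk><unk><unk><|audio_end|>"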
@@ -321,9 +325,9 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         result_text = []
         for result in output_ids:
             result = result[result != 0]
-            if result[0] == self.tokenizer.bos_id:
+            if result[0] == self.tokenizer.bos_token_id:
                 result = result[1:]
-            if result[-1] == self.tokenizer.eos_id:
+            if result[-1] == self.tokenizer.eos_token_id:
                 result = result[:-1]
             result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
         return result_text
@@ -337,10 +341,10 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         """
         result = args[0]
         result = result[result != 0]
-        if result[0] == self.tokenizer.bos_id:
+        if result[0] == self.tokenizer.bos_token_id:
             result = result[1:]
-        if result[-1] == self.tokenizer.eos_id or (
-            hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+        if result[-1] == self.tokenizer.eos_token_id or (
+            hasattr(self.tokenizer, "eot_token_id") and result[-1] == self.tokenizer.eot_token_id
         ):
             result = result[:-1]
         return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
@@ -352,8 +356,8 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         input_ids = torch.tensor(input_ids, dtype=torch.int32)

         ## image bound
-        start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
-        end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
+        start_cond = (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.im_start_token)) | (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.slice_start_token))
+        end_cond = (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.im_end_token)) | (input_ids == self.tokenizer.convert_tokens_to_ids(self.image_processor.slice_end_token))

         image_start_idx = torch.where(start_cond)[0]
         image_start_idx += 1
@@ -369,13 +373,13 @@ class MiniCPM_o_2_6Processor(ProcessorMixin):
         )

         ## audio bound
-        audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
-        audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
+        audio_start_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.audio_start_token))[0]
+        audio_end_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.audio_end_token))[0]
         assert len(audio_start_idx) == len(audio_end_idx)
         audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])

-        spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
-        spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
+        spk_start_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.spk_bos_token))[0]
+        spk_end_idx = torch.where(input_ids == self.tokenizer.convert_tokens_to_ids(self.spk_eos_token))[0]
         assert len(spk_start_idx) == len(spk_end_idx)
         spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])

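Note: all four bound computations now share one pattern: resolve the delimiter token string to an id, locate its occurrences, and pair each start position + 1 with the matching end position. A self-contained sketch with toy ids (stand-ins for `convert_tokens_to_ids` results):

    import torch

    input_ids = torch.tensor([5, 101, 7, 7, 102, 9])  # 101/102 play the start/end delimiter ids
    start_id, end_id = 101, 102

    start_idx = torch.where(input_ids == start_id)[0]
    end_idx = torch.where(input_ids == end_id)[0]
    assert len(start_idx) == len(end_idx)
    # Each row pairs the first content position with the end-marker position.
    bounds = torch.hstack([(start_idx + 1).unsqueeze(-1), end_idx.unsqueeze(-1)])
    print(bounds)  # tensor([[2, 4]])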
@@ -1,109 +0,0 @@
-# coding=utf-8
-# Copyright 2025 The OpenBMB Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import Qwen2TokenizerFast
-
-
-class MiniCPM_o_2_6TokenizerFast(Qwen2TokenizerFast):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        # image
-        self.im_start = "<image>"
-        self.im_end = "</image>"
-        self.ref_start = "<ref>"
-        self.ref_end = "</ref>"
-        self.box_start = "<box>"
-        self.box_end = "</box>"
-        self.quad_start = "<quad>"
-        self.quad_end = "</quad>"
-        self.slice_start = "<slice>"
-        self.slice_end = "</slice>"
-        self.im_id_start = "<image_id>"
-        self.im_id_end = "</image_id>"
-
-        # audio
-        self.audio_start = "<|audio_start|>"
-        self.audio_end = "<|audio_end|>"
-        self.spk_start = "<|spk_bos|>"
-        self.spk_end = "<|spk_eos|>"
-        self.tts_start = "<|tts_bos|>"
-        self.tts_end = "<|tts_eos|>"
-
-    @property
-    def eos_id(self):
-        return self.eos_token_id
-
-    @property
-    def bos_id(self):
-        return self.bos_token_id
-
-    @property
-    def unk_id(self):
-        return self.unk_token_id
-
-    @property
-    def im_start_id(self):
-        return self.convert_tokens_to_ids(self.im_start)
-
-    @property
-    def im_end_id(self):
-        return self.convert_tokens_to_ids(self.im_end)
-
-    @property
-    def slice_start_id(self):
-        return self.convert_tokens_to_ids(self.slice_start)
-
-    @property
-    def slice_end_id(self):
-        return self.convert_tokens_to_ids(self.slice_end)
-
-    @property
-    def im_id_start_id(self):
-        return self.convert_tokens_to_ids(self.im_id_start)
-
-    @property
-    def im_id_end_id(self):
-        return self.convert_tokens_to_ids(self.im_id_end)
-
-    @property
-    def audio_start_id(self):
-        return self.convert_tokens_to_ids(self.audio_start)
-
-    @property
-    def audio_end_id(self):
-        return self.convert_tokens_to_ids(self.audio_end)
-
-    @property
-    def spk_start_id(self):
-        return self.convert_tokens_to_ids(self.spk_start)
-
-    @property
-    def spk_end_id(self):
-        return self.convert_tokens_to_ids(self.spk_end)
-
-    @property
-    def tts_start_id(self):
-        return self.convert_tokens_to_ids(self.tts_start)
-
-    @property
-    def tts_end_id(self):
-        return self.convert_tokens_to_ids(self.tts_end)
-
-    @staticmethod
-    def escape(text: str) -> str:
-        return text
-
-    @staticmethod
-    def unescape(text: str) -> str:
-        return text
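Note: this deletion is the point of the commit: everything the subclass defined was either an alias for a standard attribute (`bos_id` for `bos_token_id`) or a one-line `convert_tokens_to_ids` lookup over a fixed token string, and the hunks above move each of those to its call site (or, for the audio/speaker strings, onto the processor), so the stock Qwen2TokenizerFast serves unchanged.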