Add CpmTokenizerFast (#12938)
* Add CpmTokenizerFast
* Fix isort
* Overwrite _batch_encode_plus
parent e2d22eef14
commit fd0255b41d
@@ -177,6 +177,7 @@ if is_tokenizers_available():
     from ..big_bird.tokenization_big_bird_fast import BigBirdTokenizerFast
     from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast
     from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast
+    from ..cpm.tokenization_cpm_fast import CpmTokenizerFast
     from ..deberta.tokenization_deberta_fast import DebertaTokenizerFast
     from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast
     from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast
@@ -212,6 +213,7 @@ else:
     BigBirdTokenizerFast = None
     CamembertTokenizerFast = None
     ConvBertTokenizerFast = None
+    CpmTokenizerFast = None
     DebertaTokenizerFast = None
     DistilBertTokenizerFast = None
     DPRQuestionEncoderTokenizerFast = None
@@ -308,6 +310,7 @@ NO_CONFIG_TOKENIZER = [
     BertweetTokenizer,
     ByT5Tokenizer,
     CpmTokenizer,
+    CpmTokenizerFast,
     HerbertTokenizer,
     HerbertTokenizerFast,
     PhobertTokenizer,
@@ -18,16 +18,24 @@
 
 from typing import TYPE_CHECKING
 
-from ...file_utils import _LazyModule
+from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
 
 
-_import_structure = {
-    "tokenization_cpm": ["CpmTokenizer"],
-}
+_import_structure = {}
+
+if is_sentencepiece_available():
+    _import_structure["tokenization_cpm"] = ["CpmTokenizer"]
+
+if is_tokenizers_available():
+    _import_structure["tokenization_cpm_fast"] = ["CpmTokenizerFast"]
 
 
 if TYPE_CHECKING:
-    from .tokenization_cpm import CpmTokenizer
+    if is_sentencepiece_available():
+        from .tokenization_cpm import CpmTokenizer
+
+    if is_tokenizers_available():
+        from .tokenization_cpm_fast import CpmTokenizerFast
 
 else:
     import sys
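The else: branch of this __init__ hunk is cut off after "import sys", so the runtime fallback is not visible above. As a minimal sketch only, assuming this module follows the usual _LazyModule registration pattern used elsewhere in the library (the exact call sits outside the hunk and is an assumption here):

    # Assumed tail of the else: branch (not part of the hunk above). It swaps the module
    # for a lazy proxy so tokenization_cpm / tokenization_cpm_fast are only imported when
    # CpmTokenizer or CpmTokenizerFast is first accessed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)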
@@ -92,7 +92,7 @@ class CpmTokenizer(XLNetTokenizer):
             import jieba
         except ModuleNotFoundError as error:
             raise error.__class__(
-                "You need to install jieba to use CpmTokenizer."
+                "You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
                 "See https://pypi.org/project/jieba/ for installation."
             )
         self.jieba = jieba
src/transformers/models/cpm/tokenization_cpm_fast.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from ...utils import logging
from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "TsinghuaAI/CPM-Generate": "https://huggingface.co/TsinghuaAI/CPM-Generate/resolve/main/spiece.model",
    },
    "tokenizer_file": {
        "TsinghuaAI/CPM-Generate": "https://huggingface.co/TsinghuaAI/CPM-Generate/resolve/main/tokenizer.json",
    },
}


class CpmTokenizerFast(XLNetTokenizerFast):
    """Runs pre-tokenization with the Jieba segmentation tool. It is used in CPM models."""

    def __init__(self, *args, **kwargs):
        """
        Construct a CPM tokenizer. Based on `Jieba <https://pypi.org/project/jieba/>`__ and `SentencePiece
        <https://github.com/google/sentencepiece>`__.

        This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main
        methods. Users should refer to this superclass for more information regarding those methods.

        Args:
            vocab_file (:obj:`str`):
                `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension) that
                contains the vocabulary necessary to instantiate a tokenizer.
            do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to lowercase the input when tokenizing.
            remove_space (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to strip the text when tokenizing (removing excess spaces before and after the string).
            keep_accents (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether to keep accents when tokenizing.
            bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
                The beginning of sequence token that was used during pretraining. Can be used as a sequence
                classifier token.

                .. note::

                    When building a sequence using special tokens, this is not the token that is used for the
                    beginning of sequence. The token used is the :obj:`cls_token`.
            eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
                The end of sequence token.

                .. note::

                    When building a sequence using special tokens, this is not the token that is used for the end of
                    sequence. The token used is the :obj:`sep_token`.
            unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
                The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
                this token instead.
            sep_token (:obj:`str`, `optional`, defaults to :obj:`"<sep>"`):
                The separator token, which is used when building a sequence from multiple sequences, e.g. two
                sequences for sequence classification or for a text and a question for question answering. It is also
                used as the last token of a sequence built with special tokens.
            pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
                The token used for padding, for example when batching sequences of different lengths.
            cls_token (:obj:`str`, `optional`, defaults to :obj:`"<cls>"`):
                The classifier token which is used when doing sequence classification (classification of the whole
                sequence instead of per-token classification). It is the first token of the sequence when built with
                special tokens.
            mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
                The token used for masking values. This is the token used when training this model with masked
                language modeling. This is the token which the model will try to predict.
            additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<eop>", "<eod>"]`):
                Additional special tokens used by the tokenizer.

        Attributes:
            sp_model (:obj:`SentencePieceProcessor`):
                The `SentencePiece` processor that is used for every conversion (string, tokens and IDs).
        """
        super().__init__(*args, **kwargs)
        try:
            import jieba
        except ModuleNotFoundError as error:
            raise error.__class__(
                "You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
                "See https://pypi.org/project/jieba/ for installation."
            )
        self.jieba = jieba
        self.translator = str.maketrans(" \n", "\u2582\u2583")

    def _batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
        batch_text_or_text_pairs = [
            " ".join([x.translate(self.translator) for x in self.jieba.cut(text, cut_all=False)])
            for text in batch_text_or_text_pairs
        ]
        return super()._batch_encode_plus(batch_text_or_text_pairs, *args, **kwargs)

    def _decode(self, *args, **kwargs):
        text = super()._decode(*args, **kwargs)
        text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
        return text
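For reference, a minimal usage sketch of the new fast tokenizer, assuming jieba, sentencepiece, and tokenizers are installed; the sample sentence is illustrative and the checkpoint name comes from PRETRAINED_VOCAB_FILES_MAP above:

    from transformers import CpmTokenizerFast

    # Loads spiece.model / tokenizer.json from the Hub checkpoint listed in
    # PRETRAINED_VOCAB_FILES_MAP above.
    tokenizer = CpmTokenizerFast.from_pretrained("TsinghuaAI/CPM-Generate")

    # Calling the tokenizer on a list goes through _batch_encode_plus: each text is first
    # segmented with jieba, and spaces/newlines are mapped to the placeholders
    # \u2582/\u2583 before the XLNet fast tokenizer runs.
    encoded = tokenizer(["今天天气真好，我们去散步吧。"])

    # decode() ends in _decode, which strips SentencePiece spaces and maps the
    # placeholder characters back to " " and "\n".
    print(tokenizer.decode(encoded["input_ids"][0]))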