Mirror of https://github.com/huggingface/transformers.git
Herbert Polish model (#7798)
* HerBERT transformer model for Polish language understanding.
* HerbertTokenizerFast generated with HerbertConverter
* Herbert base and large model cards
* Herbert model cards with tags
* Herbert tensorflow models
* Herbert model tests based on Bert test suit
* src/transformers/tokenization_herbert.py edited online with Bitbucket
* src/transformers/tokenization_herbert.py edited online with Bitbucket
* docs/source/model_doc/herbert.rst edited online with Bitbucket
* Herbert tokenizer tests and bug fixes
* src/transformers/configuration_herbert.py edited online with Bitbucket
* Copyrights and tests for TFHerbertModel
* model_cards/allegro/herbert-base-cased/README.md edited online with Bitbucket
* model_cards/allegro/herbert-large-cased/README.md edited online with Bitbucket
* Bug fixes after testing
* Reformat modified_only_fixup
* Proper order of configuration
* Herbert proper documentation formatting
* Formatting with make modified_only_fixup
* Dummies fixed
* Adding missing models to documentation
* Removing HerBERT model as it is a simple extension of BERT
* Update model_cards/allegro/herbert-base-cased/README.md (Co-authored-by: Julien Chaumond <chaumond@gmail.com>)
* Update model_cards/allegro/herbert-large-cased/README.md (Co-authored-by: Julien Chaumond <chaumond@gmail.com>)
* HerbertTokenizer deprecated configuration removed

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent: 99898dcd27
commit: 7b13bd01df
model_cards/allegro/herbert-base-cased/README.md (new file, 51 lines)
@@ -0,0 +1,51 @@
---
language: pl
tags:
- herbert
license: cc-by-sa-4.0
---

# HerBERT
**[HerBERT](https://en.wikipedia.org/wiki/Zbigniew_Herbert)** is a BERT-based language model trained on Polish corpora
using MLM and SSO objectives with dynamic masking of whole words.
Model training and experiments were conducted with [transformers](https://github.com/huggingface/transformers) version 2.9.

## Tokenizer
The training dataset was tokenized into subwords using ``CharBPETokenizer``, a character-level byte-pair encoding with
a vocabulary size of 50k tokens. The tokenizer itself was trained with the [tokenizers](https://github.com/huggingface/tokenizers) library.
We kindly encourage you to use the **Fast** version of the tokenizer, namely ``HerbertTokenizerFast``.
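For a quick look at the tokenizer on its own, a minimal sketch (the example sentence is arbitrary):

```python
from transformers import AutoTokenizer

# Resolves to HerbertTokenizerFast when the `tokenizers` backend is installed.
tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased", use_fast=True)

print(tokenizer.tokenize("Herbert pisał wiersze."))                         # CharBPE subword pieces
print(tokenizer.encode("Herbert pisał wiersze.", add_special_tokens=True))  # ids wrapped in <s> ... </s>
```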
## HerBERT usage

Example code:
```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")
model = AutoModel.from_pretrained("allegro/herbert-base-cased")

output = model(
    **tokenizer.batch_encode_plus(
        [
            (
                "A potem szedł środkiem drogi w kurzawie, bo zamiatał nogami, ślepy dziad prowadzony przez tłustego kundla na sznurku.",
                "A potem leciał od lasu chłopak z butelką, ale ten ujrzawszy księdza przy drodze okrążył go z dala i biegł na przełaj pól do karczmy."
            )
        ],
        padding='longest',
        add_special_tokens=True,
        return_tensors='pt'
    )
)
```
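Depending on the installed ``transformers`` version, ``output`` is either a plain tuple or a model-output object; a minimal sketch of reading the final hidden states in either case:

```python
# output[0] holds the last hidden states for tuple returns and model-output objects alike;
# its shape is (batch_size, sequence_length, hidden_size).
last_hidden_state = output[0]
print(last_hidden_state.shape)
```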
## License
CC BY-SA 4.0

## Authors
The model was trained by the **Allegro Machine Learning Research** team.

You can contact us at: <a href="mailto:klejbenchmark@allegro.pl">klejbenchmark@allegro.pl</a>
model_cards/allegro/herbert-large-cased/README.md (new file, 50 lines)
@@ -0,0 +1,50 @@
---
language: pl
tags:
- herbert
license: cc-by-sa-4.0
---
# HerBERT
**[HerBERT](https://en.wikipedia.org/wiki/Zbigniew_Herbert)** is a BERT-based language model trained on Polish corpora
using MLM and SSO objectives with dynamic masking of whole words.
Model training and experiments were conducted with [transformers](https://github.com/huggingface/transformers) version 2.9.

## Tokenizer
The training dataset was tokenized into subwords using ``CharBPETokenizer``, a character-level byte-pair encoding with
a vocabulary size of 50k tokens. The tokenizer itself was trained with the [tokenizers](https://github.com/huggingface/tokenizers) library.
We kindly encourage you to use the **Fast** version of the tokenizer, namely ``HerbertTokenizerFast``.

## HerBERT usage

Example code:
```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-large-cased")
model = AutoModel.from_pretrained("allegro/herbert-large-cased")

output = model(
    **tokenizer.batch_encode_plus(
        [
            (
                "A potem szedł środkiem drogi w kurzawie, bo zamiatał nogami, ślepy dziad prowadzony przez tłustego kundla na sznurku.",
                "A potem leciał od lasu chłopak z butelką, ale ten ujrzawszy księdza przy drodze okrążył go z dala i biegł na przełaj pól do karczmy."
            )
        ],
        padding='longest',
        add_special_tokens=True,
        return_tensors='pt'
    )
)
```

## License
CC BY-SA 4.0

## Authors
The model was trained by the **Allegro Machine Learning Research** team.

You can contact us at: <a href="mailto:klejbenchmark@allegro.pl">klejbenchmark@allegro.pl</a>
@@ -177,6 +177,7 @@ from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_fsmt import FSMTTokenizer
from .tokenization_funnel import FunnelTokenizer, FunnelTokenizerFast
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_herbert import HerbertTokenizer, HerbertTokenizerFast
from .tokenization_layoutlm import LayoutLMTokenizer, LayoutLMTokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_lxmert import LxmertTokenizer, LxmertTokenizerFast
@@ -227,6 +227,37 @@ class GPT2Converter(Converter):
        return tokenizer


class HerbertConverter(Converter):
    def converted(self) -> Tokenizer:
        tokenizer_info_str = "#version:"
        token_suffix = "</w>"

        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        if tokenizer_info_str in merges[0][0]:
            merges = merges[1:]

        tokenizer = Tokenizer(
            BPE(
                vocab,
                merges,
                dropout=None,
                unk_token=self.original_tokenizer.unk_token,
                end_of_word_suffix=token_suffix,
            )
        )

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
        tokenizer.post_processor = processors.BertProcessing(
            sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
            cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
        )

        return tokenizer


class RobertaConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
@@ -550,6 +581,7 @@ CONVERTERS = {
    "ElectraTokenizer": BertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LxmertTokenizer": BertConverter,
    "MBartTokenizer": MBartConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
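Registering ``HerbertConverter`` under ``"HerbertTokenizer"`` lets the generic slow-to-fast conversion path build a ``tokenizers``-backed tokenizer from a slow ``HerbertTokenizer``. A minimal sketch, assuming the public ``allegro/herbert-base-cased`` checkpoint is reachable:

```python
from transformers import HerbertTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow_tokenizer = HerbertTokenizer.from_pretrained("allegro/herbert-base-cased")

# Looks up CONVERTERS["HerbertTokenizer"] and returns a tokenizers.Tokenizer instance.
fast_backend = convert_slow_tokenizer(slow_tokenizer)
print(fast_backend.encode("HerBERT").tokens)
```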
src/transformers/tokenization_herbert.py (new file, 197 lines)
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, Allegro.pl, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from .tokenization_bert import BasicTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_xlm import XLMTokenizer
from .utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


class HerbertTokenizer(XLMTokenizer):
    """
    Construct a BPE tokenizer for HerBERT.

    Peculiarities:

    - uses BERT's pre-tokenizer: ``BasicTokenizer`` splits tokens on spaces, and also on punctuation.
      Each occurrence of a punctuation character is treated separately.

    - such pre-tokenized input is then BPE-subtokenized.

    This tokenizer inherits from :class:`~transformers.XLMTokenizer`, which contains most of the methods. Users
    should refer to the superclass for more information regarding those methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, **kwargs):

        kwargs["cls_token"] = "<s>"
        kwargs["unk_token"] = "<unk>"
        kwargs["pad_token"] = "<pad>"
        kwargs["mask_token"] = "<mask>"
        kwargs["sep_token"] = "</s>"
        kwargs["do_lowercase_and_remove_accent"] = False
        kwargs["additional_special_tokens"] = []

        super().__init__(**kwargs)
        self.bert_pre_tokenizer = BasicTokenizer(
            do_lower_case=False, never_split=self.all_special_tokens, tokenize_chinese_chars=False, strip_accents=False
        )

    def _tokenize(self, text):

        pre_tokens = self.bert_pre_tokenizer.tokenize(text)

        split_tokens = []
        for token in pre_tokens:
            if token:
                split_tokens.extend([t for t in self.bpe(token).split(" ")])

        return split_tokens
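The pre-tokenize-then-BPE behaviour described in the docstring can be observed directly; a minimal sketch, assuming the public ``allegro/herbert-base-cased`` vocabulary and merges files (the Polish sentence is arbitrary):

```python
from transformers import HerbertTokenizer

tokenizer = HerbertTokenizer.from_pretrained("allegro/herbert-base-cased")

# BasicTokenizer first splits on whitespace and punctuation, then each piece is BPE-split,
# so punctuation marks surface as standalone tokens and word-final pieces carry the "</w>" suffix.
print(tokenizer.tokenize("Prawda, nie plotka."))
```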
class HerbertTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "Fast" BPE tokenizer for HerBERT (backed by HuggingFace's `tokenizers` library).

    Peculiarities:

    - uses BERT's pre-tokenizer: ``BertPreTokenizer`` splits tokens on spaces, and also on punctuation.
      Each occurrence of a punctuation character is treated separately.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast`, which contains most of the methods.
    Users should refer to the superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            Path to the vocabulary file.
        merges_file (:obj:`str`):
            Path to the merges file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = HerbertTokenizer

    def __init__(self, vocab_file, merges_file, **kwargs):

        kwargs["cls_token"] = "<s>"
        kwargs["unk_token"] = "<unk>"
        kwargs["pad_token"] = "<pad>"
        kwargs["mask_token"] = "<mask>"
        kwargs["sep_token"] = "</s>"

        super().__init__(
            vocab_file,
            merges_file,
            **kwargs,
        )
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A HerBERT sequence, like a BERT sequence, has the following format:

        - single sequence: ``<s> X </s>``
        - pair of sequences: ``<s> A </s> B </s>``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """

        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep

        return cls + token_ids_0 + sep + token_ids_1 + sep
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer's ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed, to be used in a sequence-pair classification task.
        A HerBERT sequence-pair mask, like a BERT one, has the following format:

        ::

            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
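A minimal usage sketch of the helpers defined above (the input phrases are arbitrary; see also ``test_sequence_builders`` in the new test file below):

```python
from transformers import HerbertTokenizerFast

tokenizer = HerbertTokenizerFast.from_pretrained("allegro/herbert-base-cased")

ids_a = tokenizer.encode("pierwsza sekwencja", add_special_tokens=False)
ids_b = tokenizer.encode("druga sekwencja", add_special_tokens=False)

print(tokenizer.build_inputs_with_special_tokens(ids_a))             # <s> A </s>
print(tokenizer.build_inputs_with_special_tokens(ids_a, ids_b))      # <s> A </s> B </s>
print(tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b))  # 0s for A, 1s for B
```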
tests/test_tokenization_herbert.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Allegro.pl and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import unittest

from transformers.testing_utils import slow
from transformers.tokenization_herbert import VOCAB_FILES_NAMES, HerbertTokenizer, HerbertTokenizerFast

from .test_tokenization_common import TokenizerTesterMixin


class HerbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = HerbertTokenizer
    rust_tokenizer_class = HerbertTokenizerFast
    test_rust_tokenizer = True

    def setUp(self):
        super().setUp()

        vocab = [
            "<s>",
            "</s>",
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "w</w>",
            "r</w>",
            "t</w>",
            "lo",
            "low",
            "er</w>",
            "low</w>",
            "lowest</w>",
            "newer</w>",
            "wider</w>",
            ",</w>",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w") as fp:
            fp.write(json.dumps(vocab_tokens))
        with open(self.merges_file, "w") as fp:
            fp.write("\n".join(merges))

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(vocab_file=self.vocab_file, merges_file=self.merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [16, 17, 23]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "lower,newer"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    @slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("allegro/herbert-base-cased")

        text = tokenizer.encode("konstruowanie sekwencji", add_special_tokens=False)
        text_2 = tokenizer.encode("konstruowanie wielu sekwencji", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        assert encoded_sentence == [0] + text + [2]
        assert encoded_pair == [0] + text + [2] + text_2 + [2]