Simplify unk token (#12582)
* Base test
* More tests
* Fix mistake
* Add a docstring change
* Add doc ignore
* Simplify logic for unk token in Unigram tokenizers
* Remove changes from other branch
parent deecdd4939
commit 0cc2dc2456
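The change relies on the tokenizers library's UnigramTrainer accepting the unknown token directly, so the trained Unigram model already carries the correct unk_id and no longer has to be patched after training. A minimal sketch of that trainer-level behaviour, with an illustrative corpus and token strings (not taken from this commit):

from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import UnigramTrainer

# Illustrative corpus; any iterator over strings works.
corpus = [
    "hello world",
    "hello tokenizers",
    "unigram models need an unk token",
    "the trainer can register it directly",
]

tokenizer = Tokenizer(Unigram())
tokenizer.pre_tokenizer = Whitespace()

# Passing unk_token here lets the trainer set the model's unk_id itself,
# which is what makes the manual post-training patch unnecessary.
trainer = UnigramTrainer(vocab_size=100, special_tokens=["<unk>"], unk_token="<unk>")
tokenizer.train_from_iterator(corpus, trainer=trainer)

# The unknown token is resolvable without editing the serialized JSON.
print(tokenizer.token_to_id("<unk>"))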
@@ -662,22 +662,13 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
             and tokenizer_json["model"]["end_of_word_suffix"] is not None
         ):
             kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
+        if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
+            kwargs["unk_token"] = unk_token
 
         trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
         trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
         tokenizer.train_from_iterator(text_iterator, trainer=trainer)
 
-        if unk_token is not None:
-            # For Unigram tokenizers we need to set back the unk id of the model (bug in Tokenizers?)
-            trained_tokenizer_json = json.loads(tokenizer.to_str())
-            vocab = trained_tokenizer_json["model"]["vocab"]
-            unk_id = 0
-            while unk_id < len(vocab) and vocab[unk_id][0] != unk_token:
-                unk_id += 1
-            if unk_id < len(vocab):
-                trained_tokenizer_json["model"]["unk_id"] = unk_id
-            tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
-
         if post_processor is not None:
             trained_tokenizer_json = json.loads(tokenizer.to_str())
             # Almost done, we just have to adjust the token IDs in the post processor
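For context, the public entry point that exercises this branch is train_new_from_iterator on a fast tokenizer whose backend model is Unigram. A rough usage sketch, assuming the albert-base-v2 checkpoint (whose fast tokenizer is Unigram-backed) and a placeholder corpus; a real corpus would be far larger:

from transformers import AutoTokenizer

# albert-base-v2 ships a Unigram-backed fast tokenizer, so retraining it
# goes through the "Unigram" branch changed in this commit.
old_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")

# Placeholder corpus; in practice this is an iterator over a real dataset.
corpus = iter([
    "new-domain text to learn a vocabulary from",
    "more sentences give the unigram trainer something to work with",
])

new_tokenizer = old_tokenizer.train_new_from_iterator(corpus, vocab_size=100)

# The unknown token resolves correctly without any post-hoc unk_id patching.
print(new_tokenizer.unk_token, new_tokenizer.unk_token_id)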