transformers/examples/tensorflow/language-modeling-tpu/train_unigram.py
Albert Villanova del Moral a14b055b65
Pass datasets trust_remote_code (#31406)
* Pass datasets trust_remote_code

* Pass trust_remote_code in more tests

* Add trust_remote_dataset_code arg to some tests

* Revert "Temporarily pin datasets upper version to fix CI"

This reverts commit b7672826ca.

* Pass trust_remote_code in librispeech_asr_dummy docstrings

* Revert "Pin datasets<2.20.0 for examples"

This reverts commit 833fc17a3e.

* Pass trust_remote_code to all examples

* Revert "Add trust_remote_dataset_code arg to some tests" to research_projects

* Pass trust_remote_code to tests

* Pass trust_remote_code to docstrings

* Fix flax examples tests requirements

* Pass trust_remote_dataset_code arg to tests

* Replace trust_remote_dataset_code with trust_remote_code in one example

* Fix duplicate trust_remote_code

* Replace args.trust_remote_dataset_code with args.trust_remote_code

* Replace trust_remote_dataset_code with trust_remote_code in parser

* Replace trust_remote_dataset_code with trust_remote_code in dataclasses

* Replace trust_remote_dataset_code with trust_remote_code arg
2024-06-17 17:29:13 +01:00

131 lines
4.2 KiB
Python

#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""
import argparse
import logging
import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
parser.add_argument(
"--dataset_name",
type=str,
default="wikitext",
help="Name of the training. Explore datasets at: hf.co/datasets.",
)
parser.add_argument(
"--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help=(
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
),
)
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="Batch size during training.",
)
parser.add_argument(
"--vocab_size",
type=int,
default=10048,
help="Size of the desired vocabulary.",
)
parser.add_argument(
"--limit",
default=None,
type=int,
help="Limit the number of shards (used for debugging).",
)
parser.add_argument(
"--export_to_hub",
action="store_true",
)
args = parser.parse_args()
return args
def main(args):
dataset = datasets.load_dataset(
args.dataset_name, args.dataset_config, split="train", trust_remote_code=args.trust_remote_code
)
if args.limit is not None:
max_train_samples = min(len(dataset), args.limit)
dataset = dataset.select(range(max_train_samples))
logger.info(f"Limiting the dataset to {args.limit} entries.")
def batch_iterator():
for i in range(0, len(dataset), args.batch_size):
yield dataset[i : i + args.batch_size]["text"]
# Prepare the tokenizer.
tokenizer = Tokenizer(Unigram())
tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
# Prepare the trainer.
trainer = UnigramTrainer(
unk_token="<unk>",
special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
vocab_size=args.vocab_size,
)
logger.info("Training the tokenizer.")
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
logger.info("Tokenizer training complete!")
cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
tokenizer.post_processor = processors.TemplateProcessing(
single="[CLS]:0 $A:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[
("[CLS]", cls_token_id),
("[SEP]", sep_token_id),
],
)
tokenizer.decoder = decoders.Metaspace()
if args.export_to_hub:
logger.info("Exporting the trained tokenizer to Hub.")
new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
if __name__ == "__main__":
args = parse_args()
main(args)