Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-13 09:40:06 +06:00)
[Examples/TensorFlow] minor refactoring to allow compatible datasets to work (#22879)
minor refactoring to allow compatible datasets to work.
parent 10dd3a7d1c
commit 4116d1ec75
@@ -33,6 +33,15 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Prepare TFRecord shards from pre-tokenized samples of the wikitext dataset."
     )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="wikitext",
+        help="Name of the training. Explore datasets at: hf.co/datasets.",
+    )
+    parser.add_argument(
+        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
+    )
     parser.add_argument(
         "--tokenizer_name_or_path",
         type=str,
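
As a hedged illustration of what the two new flags enable (not part of the commit; the dataset and config names below are example values), any Hub dataset whose chosen split exposes a "text" column should load the same way the refactored script now does:

import datasets

# Example values standing in for --dataset_name / --dataset_config; any compatible
# dataset with a "text" column should work here.
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
print(dataset.column_names)  # ['text']
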
@@ -96,11 +105,11 @@ def get_serialized_examples(tokenized_data):
 
 
 def main(args):
-    wikitext = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split=args.split)
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split=args.split)
 
     if args.limit is not None:
-        max_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_samples))
+        max_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_samples))
         print(f"Limiting the dataset to {args.limit} entries.")
 
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
@@ -119,7 +128,7 @@ def main(args):
 
     # Tokenize the whole dataset at once.
     tokenize_fn = tokenize_function(tokenizer)
-    wikitext_tokenized = wikitext.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
+    dataset_tokenized = dataset.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
 
     # We need to concatenate all our texts together, and then split the result
     # into chunks of a fixed size, which we will call block_size. To do this, we
@@ -144,14 +153,14 @@ def main(args):
         }
         return result
 
-    grouped_dataset = wikitext_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
+    grouped_dataset = dataset_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
 
     shard_count = 0
     total_records = 0
     for shard in range(0, len(grouped_dataset), args.shard_size):
         dataset_snapshot = grouped_dataset[shard : shard + args.shard_size]
         records_containing = len(dataset_snapshot["input_ids"])
-        filename = os.path.join(split_dir, f"wikitext-{shard_count}-{records_containing}.tfrecord")
+        filename = os.path.join(split_dir, f"dataset-{shard_count}-{records_containing}.tfrecord")
         serialized_examples = get_serialized_examples(dataset_snapshot)
 
         with tf.io.TFRecordWriter(filename) as out_file:
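
The hunks above generalize the TFRecord-preparation script; the hunks below make the matching rename in the unigram-tokenizer training script. For context, a hedged sketch (not from the commit) of how the shard loop above slices the grouped dataset: indexing a datasets.Dataset with a slice returns a plain dict of column name to list, so the record count encoded in each filename is simply the length of one column.

from datasets import Dataset

# Toy stand-in for grouped_dataset; real shards hold block_size-length token chunks.
grouped = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]})
shard_size = 2  # stand-in for args.shard_size
for start in range(0, len(grouped), shard_size):
    snapshot = grouped[start : start + shard_size]  # dict: {"input_ids": [...]}
    print(start, len(snapshot["input_ids"]))        # prints 0 2, 2 2, 4 1
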
@@ -69,16 +69,16 @@ def parse_args():
 
 
 def main(args):
-    wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
 
     if args.limit is not None:
-        max_train_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_train_samples))
+        max_train_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_train_samples))
         logger.info(f"Limiting the dataset to {args.limit} entries.")
 
     def batch_iterator():
-        for i in range(0, len(wikitext), args.batch_size):
-            yield wikitext[i : i + args.batch_size]["text"]
+        for i in range(0, len(dataset), args.batch_size):
+            yield dataset[i : i + args.batch_size]["text"]
 
     # Prepare the tokenizer.
     tokenizer = Tokenizer(Unigram())
@@ -111,7 +111,7 @@ def main(args):
     if args.export_to_hub:
         logger.info("Exporting the trained tokenzier to Hub.")
         new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
-        new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
+        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
 
 
 if __name__ == "__main__":
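
To round out the picture, a hedged sketch of the batch-iterator pattern the tokenizer-training hunks rely on; the trainer settings, pre-tokenizer, and dataset/config names are illustrative assumptions, not values taken from the script:

import datasets
from tokenizers import Tokenizer, pre_tokenizers, trainers
from tokenizers.models import Unigram

dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
batch_size = 1000  # stand-in for args.batch_size

def batch_iterator():
    # Stream the "text" column in batches, the same pattern as in the diff above.
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]

tokenizer = Tokenizer(Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()  # simplified; the real script configures its own pipeline
trainer = trainers.UnigramTrainer(vocab_size=8000, special_tokens=["<unk>"], unk_token="<unk>")
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)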