add preprocessing_num_workers to run_classification.py (#31586)

Add a preprocessing_num_workers option to speed up dataset preprocessing.
This commit is contained in:
Locke 2024-06-25 19:35:50 +08:00 committed by GitHub
parent fc689d75a0
commit e73a97a2b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -133,6 +133,10 @@ class DataTrainingArguments:
) )
}, },
) )
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
overwrite_cache: bool = field( overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
) )
@ -573,6 +577,7 @@ def main():
raw_datasets = raw_datasets.map( raw_datasets = raw_datasets.map(
preprocess_function, preprocess_function,
batched=True, batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache, load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )