diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index f22da85934a..dfb781c7a4a 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -157,7 +157,7 @@ class DataTrainingArguments:
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )

     def __post_init__(self):
@@ -305,6 +305,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -312,22 +313,23 @@
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
             dataset["train"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index c5e872a4ece..950734e3ee9 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -173,7 +173,7 @@ class DataTrainingArguments:
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )

     def __post_init__(self):
@@ -269,6 +269,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -280,22 +281,23 @@
         )
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )
             raw_datasets["train"] = load_dataset(
                 extension,
-                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                **dataset_args,
             )

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index 723c64c3d02..fd69abe4a41 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -174,7 +174,7 @@ def parse_args():
         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
     )
     parser.add_argument(
-        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
+        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
     )
     args = parser.parse_args()

@@ -248,6 +248,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if args.train_file is not None:
             data_files["train"] = args.train_file
         if args.validation_file is not None:
@@ -255,20 +256,21 @@
         extension = args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, data_files=data_files)
+            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
-                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{args.validation_split_percentage}%]",
+                **dataset_args,
             )
             raw_datasets["train"] = load_dataset(
                 extension,
-                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{args.validation_split_percentage}%:]",
+                **dataset_args,
             )

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py
index c9e5bc05367..8c9f211a3ce 100755
--- a/examples/tensorflow/language-modeling/run_clm.py
+++ b/examples/tensorflow/language-modeling/run_clm.py
@@ -187,7 +187,7 @@ class DataTrainingArguments:
         },
     )
     keep_linebreaks: bool = field(
-        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
     )

     def __post_init__(self):
@@ -321,6 +321,7 @@ def main():
         )
     else:
         data_files = {}
+        dataset_args = {}
         if data_args.train_file is not None:
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
@@ -328,7 +329,8 @@
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files)
+            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # endregion