Flax MLM: Allow validation split when loading dataset from local file (#12689)

* Allow validation split when loading dataset from local file

* Flax clm & t5, enable validation split for datasets loaded from local file
This commit is contained in:
fgaim 2021-07-20 13:38:25 +02:00 committed by GitHub
parent 6f8e367ae9
commit 66197adc98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 0 deletions

View File

@ -307,6 +307,20 @@ def main():
if extension == "txt":
extension = "text"
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.

View File

@ -344,6 +344,20 @@ if __name__ == "__main__":
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.

View File

@ -471,6 +471,19 @@ if __name__ == "__main__":
extension = "text"
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.