[Examples] Add automatic dataset splitting in language-modeling examples (#9133)

* Replace jnp.split with np.split, remove textual model inputs, and ensure warmup_steps > 0

* Add automatic dataset splitting in language-modeling examples
Teven 2020-12-15 22:02:43 +01:00 committed by GitHub
parent e771749777
commit 2a7e8e1608
5 changed files with 113 additions and 16 deletions
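
The core mechanism is the datasets split-slicing syntax: when a hub dataset ships without a validation split, the first validation_split_percentage percent of train is loaded as the validation set and the remainder is reloaded as the training set. A minimal standalone sketch of that behavior, outside the example scripts (the wikitext dataset name and config are illustrative assumptions):

from datasets import load_dataset

# Illustrative dataset; any hub dataset whose only split is "train" works the same way.
dataset_name, config_name = "wikitext", "wikitext-103-raw-v1"
validation_split_percentage = 5  # mirrors the new DataTrainingArguments default

# The first 5% of "train" is carved out as the validation set ...
validation = load_dataset(dataset_name, config_name, split=f"train[:{validation_split_percentage}%]")
# ... and the remaining 95% is reloaded as the training set.
train = load_dataset(dataset_name, config_name, split=f"train[{validation_split_percentage}%:]")

print(len(train), len(validation))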

View File

@@ -113,6 +113,12 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
@@ -188,6 +194,17 @@ def main():
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ )
else:
data_files = {}
if data_args.train_file is not None:

View File

@@ -103,6 +103,12 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
max_seq_length: Optional[int] = field(
default=None,
metadata={
@@ -199,6 +205,17 @@ def main():
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ )
else:
data_files = {}
if data_args.train_file is not None:

View File

@@ -134,6 +134,12 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
max_seq_length: Optional[int] = field(
default=None,
metadata={
@@ -413,7 +419,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
if samples_to_remove != 0:
samples_idx = samples_idx[:-samples_to_remove]
sections_split = nb_samples // batch_size
- batch_idx = jnp.split(samples_idx, sections_split)
+ batch_idx = np.split(samples_idx, sections_split)
return batch_idx
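
For context on the jnp.split to np.split change above: the index batches only need to exist host-side, and np.split returns plain NumPy arrays that are cheap to iterate and index with. A self-contained sketch of what the helper computes (the _sketch name is a hypothetical stand-in, not the script's function):

import numpy as np

def generate_batch_splits_sketch(samples_idx: np.ndarray, batch_size: int) -> list:
    # Drop the trailing samples that do not fill a complete batch ...
    nb_samples = len(samples_idx)
    samples_to_remove = nb_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    # ... then split the remaining indices into equally sized batches on the host.
    sections_split = len(samples_idx) // batch_size
    return np.split(samples_idx, sections_split)

batches = generate_batch_splits_sketch(np.arange(10), batch_size=4)
print([b.tolist() for b in batches])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
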
@@ -473,6 +479,17 @@ if __name__ == "__main__":
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ )
else:
data_files = {}
if data_args.train_file is not None:
@@ -525,9 +542,9 @@ if __name__ == "__main__":
def tokenize_function(examples):
# Remove empty lines
- examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
- examples["text"],
+ examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
@@ -536,9 +553,10 @@ if __name__ == "__main__":
tokenized_datasets = datasets.map(
tokenize_function,
+ input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
- remove_columns=[text_column_name],
+ remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
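
The two hunks above work together: input_columns=[text_column_name] makes datasets.map pass the raw list of text values to tokenize_function instead of a dict of columns (hence examples rather than examples["text"]), and remove_columns=column_names drops every original column so the filtered batch is allowed to shrink. A small self-contained sketch of that calling convention (the toy in-memory dataset and the "kept" column are assumptions for illustration):

from datasets import Dataset

toy = Dataset.from_dict({"text": ["hello world", "", "   ", "another line"]})

def keep_non_empty(texts):
    # With input_columns, `texts` is the list of "text" values in the batch,
    # not a dict of columns, so it can be filtered directly.
    return {"kept": [line for line in texts if len(line) > 0 and not line.isspace()]}

# remove_columns drops the original "text" column, so the output batch
# is allowed to contain fewer rows than the input batch.
filtered = toy.map(keep_non_empty, input_columns=["text"], batched=True, remove_columns=["text"])
print(filtered["kept"])  # ['hello world', 'another line']
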
@@ -566,8 +584,9 @@ if __name__ == "__main__":
).create(model.params)
# Create learning rate scheduler
+ # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
lr_scheduler_fn = create_learning_rate_scheduler(
- base_learning_rate=training_args.learning_rate, warmup_steps=training_args.warmup_steps
+ base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
)
# Create parallel version of the training and evaluation steps
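
The clamp above exists because a zero-step warmup makes the Flax scheduler emit NaN learning rates, while a single warmup step is functionally equivalent to none for these examples. A trivial sketch of the intended guard (the helper name is hypothetical):

def safe_warmup_steps(requested: int) -> int:
    # warmup_steps == 0 makes the Flax learning-rate scheduler return NaNs,
    # so never go below one warmup step; larger values pass through unchanged.
    return max(requested, 1)

assert safe_warmup_steps(0) == 1
assert safe_warmup_steps(500) == 500
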
@@ -606,13 +625,13 @@ if __name__ == "__main__":
epochs.write(f"Loss: {loss}")
# ======================== Evaluating ==============================
- nb_eval_samples = len(tokenized_datasets["test"])
+ nb_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(nb_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
- samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx]
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward

View File

@@ -91,6 +91,12 @@ class DataTrainingArguments:
Arguments pertaining to what data we are going to input our model for training and eval.
"""
+ dataset_name: Optional[str] = field(
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+ )
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
@@ -107,6 +113,12 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
max_seq_length: Optional[int] = field(
default=None,
metadata={
@@ -203,15 +215,30 @@ def main():
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
- data_files = {}
- if data_args.train_file is not None:
- data_files["train"] = data_args.train_file
- if data_args.validation_file is not None:
- data_files["validation"] = data_args.validation_file
- extension = data_args.train_file.split(".")[-1]
- if extension == "txt":
- extension = "text"
- datasets = load_dataset(extension, data_files=data_files)
+ if data_args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ )
+ else:
+ data_files = {}
+ if data_args.train_file is not None:
+ data_files["train"] = data_args.train_file
+ if data_args.validation_file is not None:
+ data_files["validation"] = data_args.validation_file
+ extension = data_args.train_file.split(".")[-1]
+ if extension == "txt":
+ extension = "text"
+ datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
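
The local-file fallback is unchanged: a ".txt" extension maps to the generic "text" builder, which yields one example per line. A runnable sketch of that path, with a hypothetical temporary file standing in for --train_file:

import os
import tempfile
from datasets import load_dataset

# Hypothetical local text file standing in for --train_file; the "text" builder
# is what a ".txt" extension is mapped to in the fallback branch above.
tmpdir = tempfile.mkdtemp()
train_path = os.path.join(tmpdir, "train.txt")
with open(train_path, "w") as f:
    f.write("first line\nsecond line\n")

datasets = load_dataset("text", data_files={"train": train_path})
print(datasets["train"][0])  # {'text': 'first line'}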

View File

@@ -93,6 +93,12 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
max_seq_length: int = field(
default=512,
metadata={
@@ -196,6 +202,17 @@ def main():
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+ if "validation" not in datasets.keys():
+ datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ )
+ datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ )
else:
data_files = {}
if data_args.train_file is not None:
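
The same validation_split_percentage field is added to DataTrainingArguments in all five scripts, and since these examples build their command lines with HfArgumentParser, it surfaces as a --validation_split_percentage flag defaulting to 5. A minimal sketch of that mechanism, using a stripped-down stand-in dataclass rather than the examples' full DataTrainingArguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class SplitArguments:
    # Stand-in for the field added to DataTrainingArguments in each script.
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )

parser = HfArgumentParser(SplitArguments)
(split_args,) = parser.parse_args_into_dataclasses(args=["--validation_split_percentage", "10"])
print(split_args.validation_split_percentage)  # 10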