[Examples] Add automatic dataset splitting in language-modeling examples (#9133)

* Replaced jnp.split with np.split, removed the textual model inputs, and ensured warmup_steps > 0

* Add automatic dataset splitting in language-modeling examples
Author: Teven, 2020-12-15 22:02:43 +01:00 (committed by GitHub)
Commit: 2a7e8e1608 (parent: e771749777)
5 changed files with 113 additions and 16 deletions
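To see the new behaviour outside the example scripts, here is a minimal, standalone sketch of the split logic the diffs below add. The dataset name, config, and percentage are placeholders (the scripts read them from DataTrainingArguments), and the branch only fires for datasets that publish no validation split:

from datasets import load_dataset

dataset_name = "wikitext"                  # placeholder hub dataset
dataset_config_name = "wikitext-2-raw-v1"  # placeholder config
validation_split_percentage = 5            # same default as the new field

datasets = load_dataset(dataset_name, dataset_config_name)
if "validation" not in datasets.keys():
    # Take the first N% of "train" as the validation set ...
    datasets["validation"] = load_dataset(
        dataset_name,
        dataset_config_name,
        split=f"train[:{validation_split_percentage}%]",
    )
    # ... and keep the remaining (100 - N)% as the train set.
    datasets["train"] = load_dataset(
        dataset_name,
        dataset_config_name,
        split=f"train[{validation_split_percentage}%:]",
    )

print({split: len(ds) for split, ds in datasets.items()})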

File 1 of 5

@@ -113,6 +113,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
@@ -188,6 +194,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:

File 2 of 5

@@ -103,6 +103,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -199,6 +205,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:

File 3 of 5

@@ -134,6 +134,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -413,7 +419,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     if samples_to_remove != 0:
         samples_idx = samples_idx[:-samples_to_remove]
     sections_split = nb_samples // batch_size
-    batch_idx = jnp.split(samples_idx, sections_split)
+    batch_idx = np.split(samples_idx, sections_split)
     return batch_idx
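For context on the jnp.split -> np.split change above, here is a self-contained sketch of the batching helper after the fix. The first two lines of the body are an assumption reconstructed from the surrounding names rather than a copy of the script, and plain NumPy arrays stand in for the jnp arrays the script passes (np.split accepts either):

import numpy as np

def generate_batch_splits(samples_idx, batch_size):
    nb_samples = len(samples_idx)                # assumed from context
    samples_to_remove = nb_samples % batch_size  # assumed from context
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]  # drop the ragged tail
    sections_split = nb_samples // batch_size
    # Index bookkeeping stays on the host, so np.split is sufficient (and cheaper than jnp.split).
    batch_idx = np.split(samples_idx, sections_split)
    return batch_idx

print([b.tolist() for b in generate_batch_splits(np.arange(10), batch_size=4)])
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]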
@@ -473,6 +479,17 @@ if __name__ == "__main__":
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -525,9 +542,9 @@ if __name__ == "__main__":
     def tokenize_function(examples):
         # Remove empty lines
-        examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
+        examples = [line for line in examples if len(line) > 0 and not line.isspace()]
         return tokenizer(
-            examples["text"],
+            examples,
             return_special_tokens_mask=True,
             padding=padding,
             truncation=True,
@@ -536,9 +553,10 @@ if __name__ == "__main__":
     tokenized_datasets = datasets.map(
         tokenize_function,
+        input_columns=[text_column_name],
         batched=True,
         num_proc=data_args.preprocessing_num_workers,
-        remove_columns=[text_column_name],
+        remove_columns=column_names,
         load_from_cache_file=not data_args.overwrite_cache,
     )
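The two hunks above go together: passing input_columns=[text_column_name] to datasets.map hands tokenize_function the raw list of strings directly, which is why the examples["text"] indexing (the "textual model inputs") disappears, and remove_columns drops the original columns so the filtered batches may be shorter than the input. A toy sketch of that interaction, using a hypothetical in-memory dataset and a placeholder checkpoint:

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")  # placeholder checkpoint
dataset = Dataset.from_dict({"text": ["Hello world.", "", "   ", "Another line."]})

def tokenize_function(examples):
    # With input_columns=["text"], `examples` is already a list of strings.
    examples = [line for line in examples if len(line) > 0 and not line.isspace()]
    return tokenizer(examples, return_special_tokens_mask=True, truncation=True)

tokenized = dataset.map(
    tokenize_function,
    input_columns=["text"],
    batched=True,
    remove_columns=dataset.column_names,  # the scripts pass the precomputed column_names
)
print(tokenized.column_names, len(tokenized))  # tokenizer outputs only; empty lines removed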
@@ -566,8 +584,9 @@ if __name__ == "__main__":
     ).create(model.params)

     # Create learning rate scheduler
+    # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
     lr_scheduler_fn = create_learning_rate_scheduler(
-        base_learning_rate=training_args.learning_rate, warmup_steps=training_args.warmup_steps
+        base_learning_rate=training_args.learning_rate, warmup_steps=min(training_args.warmup_steps, 1)
     )

     # Create parallel version of the training and evaluation steps
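The comment added above documents why warmup_steps is touched: a value of 0 makes the Flax optimizer return NaNs, and a single warmup step is functionally equivalent. As a sketch, the guard the commit message describes ("ensuring warmup_steps > 0") can be expressed with max(); note that the committed line itself uses min(training_args.warmup_steps, 1), so this shows the stated intent rather than a copy of the code:

def safe_warmup_steps(requested: int) -> int:
    # Treat a requested warmup of 0 as 1 step; larger values pass through unchanged.
    return max(requested, 1)

assert safe_warmup_steps(0) == 1
assert safe_warmup_steps(500) == 500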
@@ -606,13 +625,13 @@ if __name__ == "__main__":
         epochs.write(f"Loss: {loss}")

         # ======================== Evaluating ==============================
-        nb_eval_samples = len(tokenized_datasets["test"])
+        nb_eval_samples = len(tokenized_datasets["validation"])
         eval_samples_idx = jnp.arange(nb_eval_samples)
         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

         eval_metrics = []
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx]
+            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)

             # Model forward

File 4 of 5

@@ -91,6 +91,12 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -107,6 +113,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -203,15 +215,30 @@ def main():
     #
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    data_files = {}
-    if data_args.train_file is not None:
-        data_files["train"] = data_args.train_file
-    if data_args.validation_file is not None:
-        data_files["validation"] = data_args.validation_file
-    extension = data_args.train_file.split(".")[-1]
-    if extension == "txt":
-        extension = "text"
-    datasets = load_dataset(extension, data_files=data_files)
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
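For the else-branch kept above, a short sketch of what the extension-based fallback does with local files; the file names are hypothetical, and the "txt" -> "text" rename is needed because the datasets loader for plain text files is registered as "text":

from datasets import load_dataset

# Hypothetical local files, mirroring --train_file / --validation_file.
data_files = {"train": "train.txt", "validation": "valid.txt"}
extension = data_files["train"].split(".")[-1]
if extension == "txt":
    extension = "text"  # the plain-text loader is called "text", not "txt"
datasets = load_dataset(extension, data_files=data_files)
print(datasets)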

File 5 of 5

@@ -93,6 +93,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: int = field(
         default=512,
         metadata={
@@ -196,6 +202,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None: