diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index fb9d6e27d4d..4e07441c14b 100644
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -113,6 +113,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
@@ -188,6 +194,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:
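Every script touched below repeats the same fallback: when the hub dataset defines no validation split, two extra load_dataset calls carve one out of the train split using the datasets library's percent-slicing syntax. A minimal standalone sketch of that pattern (the helper name and defaults are illustrative, not part of the diff):

from datasets import load_dataset

def load_with_validation_fallback(name, config=None, validation_split_percentage=5):
    # Load whatever splits the dataset defines on the hub.
    datasets = load_dataset(name, config)
    if "validation" not in datasets.keys():
        # No validation split shipped: take the head of train as validation
        # and keep the remainder as the new train split.
        datasets["validation"] = load_dataset(
            name, config, split=f"train[:{validation_split_percentage}%]"
        )
        datasets["train"] = load_dataset(
            name, config, split=f"train[{validation_split_percentage}%:]"
        )
    return datasets

Because train[:5%] and train[5%:] are disjoint slices, no training example leaks into evaluation.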
diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py
index 664128eaf9f..4ae0b622695 100644
--- a/examples/language-modeling/run_mlm.py
+++ b/examples/language-modeling/run_mlm.py
@@ -103,6 +103,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -199,6 +205,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:
diff --git a/examples/language-modeling/run_mlm_flax.py b/examples/language-modeling/run_mlm_flax.py
index 0c2f0622a39..5fe4aefce83 100644
--- a/examples/language-modeling/run_mlm_flax.py
+++ b/examples/language-modeling/run_mlm_flax.py
@@ -134,6 +134,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -413,7 +419,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     if samples_to_remove != 0:
         samples_idx = samples_idx[:-samples_to_remove]
     sections_split = nb_samples // batch_size
-    batch_idx = jnp.split(samples_idx, sections_split)
+    batch_idx = np.split(samples_idx, sections_split)

     return batch_idx

@@ -473,6 +479,17 @@ if __name__ == "__main__":
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -525,9 +542,9 @@ if __name__ == "__main__":

     def tokenize_function(examples):
         # Remove empty lines
-        examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
+        examples = [line for line in examples if len(line) > 0 and not line.isspace()]
         return tokenizer(
-            examples["text"],
+            examples,
             return_special_tokens_mask=True,
             padding=padding,
             truncation=True,
@@ -536,9 +553,10 @@ if __name__ == "__main__":

     tokenized_datasets = datasets.map(
         tokenize_function,
+        input_columns=[text_column_name],
         batched=True,
         num_proc=data_args.preprocessing_num_workers,
-        remove_columns=[text_column_name],
+        remove_columns=column_names,
         load_from_cache_file=not data_args.overwrite_cache,
     )

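The map call above stops passing tokenize_function the whole example dict: with input_columns=[text_column_name], the datasets library hands the function just that column's values as a plain list, and remove_columns=column_names drops every original column from the tokenized output. A small sketch of that calling convention (toy data and a stand-in feature instead of the real tokenizer):

from datasets import Dataset

raw = Dataset.from_dict({"text": ["hello world", "   ", ""], "id": [0, 1, 2]})

def tokenize_function(texts):
    # With input_columns=["text"], `texts` is a list of strings, not a dict of columns.
    texts = [line for line in texts if len(line) > 0 and not line.isspace()]
    # A real script would call the tokenizer here; character count is a stand-in feature.
    return {"n_chars": [len(t) for t in texts]}

tokenized = raw.map(
    tokenize_function,
    input_columns=["text"],           # pass only the text column, positionally
    batched=True,
    remove_columns=raw.column_names,  # drop "text" and "id" from the output
)

Removing all original columns is also what lets the batched function return fewer rows than it received (the empty lines are filtered out) without a column-length mismatch.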
@@ -566,8 +584,9 @@ if __name__ == "__main__":
     ).create(model.params)

     # Create learning rate scheduler
+    # warmup_steps = 0 causes the Flax optimizer to return NaNs; warmup_steps = 1 is functionally equivalent.
     lr_scheduler_fn = create_learning_rate_scheduler(
-        base_learning_rate=training_args.learning_rate, warmup_steps=training_args.warmup_steps
+        base_learning_rate=training_args.learning_rate, warmup_steps=min(training_args.warmup_steps, 1)
     )

     # Create parallel version of the training and evaluation steps
@@ -606,13 +625,13 @@ if __name__ == "__main__":
        epochs.write(f"Loss: {loss}")

         # ======================== Evaluating ==============================
-        nb_eval_samples = len(tokenized_datasets["test"])
+        nb_eval_samples = len(tokenized_datasets["validation"])
         eval_samples_idx = jnp.arange(nb_eval_samples)
         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

         eval_metrics = []
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx]
+            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)

             # Model forward
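With the jnp.split call swapped for np.split above, the batch-index bookkeeping stays in host-side NumPy instead of going through a JAX op. A self-contained sketch of the resulting helper; the first two lines of the body are not shown in the hunk and are reconstructed here, so they may not match the script verbatim:

import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> list:
    # Drop the ragged tail so every batch is exactly batch_size long.
    nb_samples = len(samples_idx)
    samples_to_remove = nb_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = nb_samples // batch_size
    # np.split returns a list of equally sized index arrays.
    return np.split(samples_idx, sections_split)

generate_batch_splits(np.arange(10), batch_size=4)
# -> [array([0, 1, 2, 3]), array([4, 5, 6, 7])]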
diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/language-modeling/run_mlm_wwm.py
index 4686a64047e..228205ec9a6 100644
--- a/examples/language-modeling/run_mlm_wwm.py
+++ b/examples/language-modeling/run_mlm_wwm.py
@@ -91,6 +91,12 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -107,6 +113,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
         default=None,
         metadata={
@@ -203,15 +215,30 @@ def main():
     #
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
-    data_files = {}
-    if data_args.train_file is not None:
-        data_files["train"] = data_args.train_file
-    if data_args.validation_file is not None:
-        data_files["validation"] = data_args.validation_file
-    extension = data_args.train_file.split(".")[-1]
-    if extension == "txt":
-        extension = "text"
-    datasets = load_dataset(extension, data_files=data_files)
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        datasets = load_dataset(extension, data_files=data_files)

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py
index cb64716d262..4b603973bd2 100644
--- a/examples/language-modeling/run_plm.py
+++ b/examples/language-modeling/run_plm.py
@@ -93,6 +93,12 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: int = field(
         default=512,
         metadata={
@@ -196,6 +202,17 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        if "validation" not in datasets.keys():
+            datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+            )
+            datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+            )
     else:
         data_files = {}
         if data_args.train_file is not None:
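run_mlm_wwm.py previously accepted only local files; the hunk above gives it the same hub-or-local branch as the other scripts, with the local branch inferring the datasets builder from the file extension. A condensed sketch of just that local-file branch (the helper name is illustrative):

from datasets import load_dataset

def load_from_local_files(train_file, validation_file=None):
    data_files = {"train": train_file}
    if validation_file is not None:
        data_files["validation"] = validation_file
    # Infer the builder from the extension; plain .txt files map onto the
    # generic "text" builder (one example per line).
    extension = train_file.split(".")[-1]
    if extension == "txt":
        extension = "text"
    return load_dataset(extension, data_files=data_files)

For example, load_from_local_files("corpus.txt") returns a DatasetDict with a single "train" split built by the "text" loader; csv or json files are routed to the matching builder by the same extension check.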