Mirror of https://github.com/huggingface/transformers.git
Added max_sample_ arguments (#10551)
* reverted changes of logging and saving metrics
* added max_sample arguments
* fixed code
* white space diff
* reformatting code
* reformatted code
This commit is contained in:
parent 917f104502
commit dfd16af832
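Every file in this diff follows the same pattern: new max_train_samples / max_val_samples (and, where relevant, max_test_samples) arguments, a check that the required split exists, a Dataset.select(range(...)) call that truncates the split, and a *_samples entry recorded in the metrics. A minimal sketch of that pattern, assuming only the `datasets` library (the toy dataset and variable names below are illustrative, not taken from any one script):

    from datasets import Dataset

    # Toy stand-in for a real training split loaded by one of the example scripts.
    raw_train = Dataset.from_dict({"text": [f"example {i}" for i in range(1000)]})

    max_train_samples = 8  # would normally come from the --max_train_samples CLI argument
    if max_train_samples is not None:
        # Truncate the split for a quick debugging run.
        raw_train = raw_train.select(range(max_train_samples))

    # The scripts also record how many examples were actually used.
    train_samples = min(max_train_samples, len(raw_train)) if max_train_samples is not None else len(raw_train)
    print(len(raw_train), train_samples)

Selecting before the preprocessing .map(...) keeps debugging runs fast; the question-answering scripts additionally select again after feature creation, since feature creation can increase the number of rows.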
@@ -114,6 +114,21 @@ class DataTrainingArguments:
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )

    block_size: Optional[int] = field(
        default=None,
        metadata={
@@ -346,6 +361,7 @@ def main():
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
@@ -353,12 +369,26 @@ def main():
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"] if training_args.do_train else None,
        eval_dataset=lm_datasets["validation"] if training_args.do_eval else None,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
@@ -377,24 +407,28 @@ def main():

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        metrics = trainer.evaluate()

        perplexity = math.exp(eval_output["eval_loss"])
        results["perplexity"] = perplexity
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
        perplexity = math.exp(metrics["eval_loss"])
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -146,6 +146,20 @@ class DataTrainingArguments:
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -380,6 +394,7 @@ def main():
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
@@ -387,6 +402,20 @@ def main():
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -395,8 +424,8 @@ def main():
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
@@ -413,24 +442,28 @@ def main():
        trainer.save_model() # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        metrics = trainer.evaluate()

        perplexity = math.exp(eval_output["eval_loss"])
        results["perplexity"] = perplexity
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
        perplexity = math.exp(metrics["eval_loss"])
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -143,6 +143,20 @@ class DataTrainingArguments:
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -358,6 +372,7 @@ def main():
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
@@ -365,6 +380,20 @@ def main():
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Data collator
    data_collator = DataCollatorForPermutationLanguageModeling(
        tokenizer=tokenizer,
@@ -376,8 +405,8 @@ def main():
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
@@ -394,24 +423,28 @@ def main():
        trainer.save_model() # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        metrics = trainer.evaluate()

        perplexity = math.exp(eval_output["eval_loss"])
        results["perplexity"] = perplexity
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
        perplexity = math.exp(metrics["eval_loss"])
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -116,6 +116,20 @@ class DataTrainingArguments:
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        if self.train_file is not None:
@@ -328,12 +342,31 @@ def main():
        # Un-flatten
        return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}

    tokenized_datasets = datasets.map(
        preprocess_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )
    if training_args.do_train:
        train_dataset = datasets["train"]
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_eval:
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator
    data_collator = (
@@ -352,8 +385,8 @@ def main():
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
@@ -371,21 +404,25 @@ def main():
        trainer.save_model() # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        results = trainer.evaluate()
        metrics = trainer.evaluate()
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -206,10 +206,14 @@ def main():

        result = trainer.evaluate()

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)
        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info(" %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

            results.update(result)
        results.update(result)

    return results

@@ -118,6 +118,20 @@ class DataTrainingArguments:
            "be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
@@ -360,13 +374,23 @@ def main():
        return tokenized_examples

    if training_args.do_train:
        train_dataset = datasets["train"].map(
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            # We will select sample from whole data if agument is specified
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        # Create train feature from dataset
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Number of samples might increase during Feature Creation, We select only specified max samples
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    # Validation preprocessing
    def prepare_validation_features(examples):
@@ -411,13 +435,23 @@ def main():
        return tokenized_examples

    if training_args.do_eval:
        validation_dataset = datasets["validation"].map(
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_val_samples is not None:
            # We will select sample from whole data
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        # Validation Feature Creation
        eval_dataset = eval_dataset.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_val_samples is not None:
            # During Feature creation dataset samples might increase, we will select required samples again
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Data collator
    # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
@@ -462,7 +496,7 @@ def main():
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=validation_dataset if training_args.do_eval else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        eval_examples=datasets["validation"] if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
@@ -482,20 +516,25 @@ def main():
        trainer.save_model() # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        results = trainer.evaluate()
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -117,6 +117,20 @@ class DataTrainingArguments:
            "be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
@@ -373,13 +387,23 @@ def main():
        return tokenized_examples

    if training_args.do_train:
        train_dataset = datasets["train"].map(
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            # Select samples from Dataset, This will help to decrease processing time
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        # Create Training Features
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_train_samples is not None:
            # Select samples from dataset again since Feature Creation might increase number of features
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    # Validation preprocessing
    def prepare_validation_features(examples):
@@ -448,13 +472,23 @@ def main():
        return tokenized_examples

    if training_args.do_eval:
        validation_dataset = datasets["validation"].map(
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_val_samples is not None:
            # Selecting Eval Samples from Dataset
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        # Create Features from Eval Dataset
        eval_dataset = eval_dataset.map(
            prepare_validation_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
        if data_args.max_val_samples is not None:
            # Selecting Samples from Dataset again since Feature Creation might increase samples size
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Data collator
    # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
@@ -501,7 +535,7 @@ def main():
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=validation_dataset if training_args.do_eval else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        eval_examples=datasets["validation"] if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
@@ -521,20 +555,26 @@ def main():
        trainer.save_model() # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        results = trainer.evaluate()
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

    return results
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):

@@ -601,7 +601,6 @@ def main():
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

@@ -614,6 +613,7 @@ def main():
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # predict
    if training_args.do_predict:
        logger.info("*** Test ***")

@@ -640,8 +640,6 @@ def main():
                with open(output_test_preds_file, "w") as writer:
                    writer.write("\n".join(test_preds))

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)

@@ -15,6 +15,7 @@


import argparse
import json
import logging
import os
import sys
@@ -64,6 +65,17 @@ def get_setup_file():
    return args.f


def get_results(output_dir):
    results = {}
    path = os.path.join(output_dir, "all_results.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            results = json.load(f)
    else:
        raise ValueError(f"can't find {path}")
    return results


def is_cuda_and_apex_available():
    is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
    return is_using_cuda and is_apex_available()
@@ -98,7 +110,8 @@ class ExamplesTests(TestCasePlus):
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            result = run_glue.main()
            run_glue.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    @require_torch_non_multi_gpu_but_fix_me
@@ -130,7 +143,8 @@ class ExamplesTests(TestCasePlus):
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_clm.main()
            run_clm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)

    @require_torch_non_multi_gpu_but_fix_me
@@ -156,7 +170,8 @@ class ExamplesTests(TestCasePlus):
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_mlm.main()
            run_mlm.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)

    @require_torch_non_multi_gpu_but_fix_me
@@ -185,7 +200,8 @@ class ExamplesTests(TestCasePlus):
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_ner.main()
            run_ner.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertGreaterEqual(result["eval_precision"], 0.75)
            self.assertLess(result["eval_loss"], 0.5)
@@ -214,7 +230,8 @@ class ExamplesTests(TestCasePlus):
        """.split()

        with patch.object(sys, "argv", testargs):
            result = run_squad.main()
            run_squad.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["f1"], 30)
            self.assertGreaterEqual(result["exact"], 30)

@@ -241,7 +258,8 @@ class ExamplesTests(TestCasePlus):
        """.split()

        with patch.object(sys, "argv", testargs):
            result = run_swag.main()
            run_swag.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    @require_torch_non_multi_gpu_but_fix_me
@@ -288,8 +306,8 @@ class ExamplesTests(TestCasePlus):
        """.split()

        with patch.object(sys, "argv", testargs):
            result = run_seq2seq.main()

            run_seq2seq.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
@@ -323,5 +341,6 @@ class ExamplesTests(TestCasePlus):
        """.split()

        with patch.object(sys, "argv", testargs):
            result = run_seq2seq.main()
            run_seq2seq.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)

@@ -89,6 +89,27 @@ class DataTrainingArguments:
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
@@ -353,12 +374,41 @@ def main():
        result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    train_dataset = datasets["train"]
    eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
    if data_args.task_name is not None or data_args.test_file is not None:
    if training_args.do_eval:
        if "validation" not in datasets and "validation_matched" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in datasets and "test_matched" not in datasets:
            raise ValueError("--do_predict requires a test dataset")
        test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_test_samples is not None:
            test_dataset = test_dataset.select(range(data_args.max_test_samples))
        test_dataset = test_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
@@ -417,6 +467,10 @@ def main():

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model() # Saves the tokenizer too for easy upload

@@ -425,7 +479,6 @@ def main():
        trainer.save_state()

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

@@ -437,12 +490,13 @@ def main():
            eval_datasets.append(datasets["validation_mismatched"])

        for eval_dataset, task in zip(eval_datasets, tasks):
            eval_result = trainer.evaluate(eval_dataset=eval_dataset)
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            trainer.log_metrics("eval", eval_result)
            trainer.save_metrics("eval", eval_result)
            max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
            metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

            eval_results.update(eval_result)
            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Test ***")
@@ -471,7 +525,6 @@ def main():
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")
    return eval_results


def _mp_fn(index):

@@ -247,10 +247,18 @@ def main():
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()
        trainer.log_metrics("eval", result)
        trainer.save_metrics("eval", result)
        results.update(result)
        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")

        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")

            for key, value in result.items():
                logger.info(" %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

            results.update(result)

    return results

@@ -294,9 +294,16 @@ def main():
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        result = trainer.evaluate()
        trainer.log_metrics("eval", result)
        trainer.save_metrics("eval", result)
        results.update(result)
        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")

        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")

            for key, value in result.items():
                logger.info(" %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

            results.update(result)

    return results

@@ -73,6 +73,27 @@ class DataTrainingArguments:
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
    server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})

@@ -238,12 +259,23 @@ def main():
            truncation=True,
        )

    train_dataset = train_dataset.map(
        preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
    )
    eval_dataset = eval_dataset.map(
        preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
    )
    if training_args.do_train:
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_eval:
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
@@ -288,6 +320,10 @@ def main():
            model_path = None
        train_result = trainer.train(model_path=model_path)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model() # Saves the tokenizer too for easy upload

@@ -296,15 +332,15 @@ def main():
        trainer.save_state()

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        eval_result = trainer.evaluate(eval_dataset=eval_dataset)
        trainer.log_metrics("eval", eval_result)
        trainer.save_metrics("eval", eval_result)
        eval_results.update(eval_result)
        metrics = trainer.evaluate(eval_dataset=eval_dataset)

    return eval_results
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":

@@ -117,6 +117,27 @@ class DataTrainingArguments:
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    label_all_tokens: bool = field(
        default=False,
        metadata={
@@ -321,12 +342,44 @@ def main():
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    tokenized_datasets = datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )
    if training_args.do_train:
        if "train" not in datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            tokenize_and_align_labels,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_eval:
        if "validation" not in datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
        eval_dataset = eval_dataset.map(
            tokenize_and_align_labels,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_predict:
        if "test" not in datasets:
            raise ValueError("--do_predict requires a test dataset")
        test_dataset = datasets["test"]
        if data_args.max_test_samples is not None:
            test_dataset = test_dataset.select(range(data_args.max_test_samples))
        test_dataset = test_dataset.map(
            tokenize_and_align_labels,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator
    data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
@@ -371,8 +424,8 @@ def main():
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
@@ -390,25 +443,31 @@ def main():
        metrics = train_result.metrics
        trainer.save_model() # Saves the tokenizer too for easy upload

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        results = trainer.evaluate()
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        logger.info("*** Predict ***")

        test_dataset = tokenized_datasets["test"]
        predictions, labels, metrics = trainer.predict(test_dataset)
        predictions = np.argmax(predictions, axis=2)

@@ -428,8 +487,6 @@ def main():
                for prediction in true_predictions:
                    writer.write(" ".join(prediction) + "\n")

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
