Remove unused test_dataset (#34516)

This commit is contained in:
Eon Kim 2024-11-05 23:01:25 +09:00 committed by GitHub
parent 663c851239
commit 45b0c7680c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -141,10 +141,6 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input testing data file (a jsonlines file)."},
)
max_seq_length: Optional[int] = field(
default=128,
metadata={
@ -190,9 +186,6 @@ class DataTrainingArguments:
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
dataset_name_mapping = {
@ -315,9 +308,6 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
dataset = load_dataset(
extension,
data_files=data_files,
@ -387,8 +377,6 @@ def main():
column_names = dataset["train"].column_names
elif training_args.do_eval:
column_names = dataset["validation"].column_names
elif training_args.do_predict:
column_names = dataset["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
@ -490,29 +478,6 @@ def main():
# Transform images on the fly as doing it on the whole dataset takes too much time.
eval_dataset.set_transform(transform_images)
if training_args.do_predict:
if "test" not in dataset:
raise ValueError("--do_predict requires a test dataset")
test_dataset = dataset["test"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
test_dataset = test_dataset.select(range(max_eval_samples))
test_dataset = test_dataset.filter(
filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
)
test_dataset = test_dataset.map(
function=tokenize_captions,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[col for col in column_names if col != image_column],
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on test dataset",
)
# Transform images on the fly as doing it on the whole dataset takes too much time.
test_dataset.set_transform(transform_images)
# 8. Initialize our trainer
trainer = Trainer(
model=model,