Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Remove unused test_dataset (#34516)
parent 663c851239
commit 45b0c7680c
@@ -141,10 +141,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
     )
-    test_file: Optional[str] = field(
-        default=None,
-        metadata={"help": "An optional input testing data file (a jsonlines file)."},
-    )
     max_seq_length: Optional[int] = field(
         default=128,
         metadata={
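For context, these dataclass fields are what transformers' HfArgumentParser exposes as CLI flags. A minimal standalone sketch (trimmed field set, illustrative usage, not part of the diff):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Field shape mirrors the diff; only the surviving validation_file/max_seq_length fields are kept here.
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
    )
    max_seq_length: Optional[int] = field(default=128)


# e.g. python run_clip.py --validation_file val.json --max_seq_length 64
parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.validation_file, data_args.max_seq_length)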
@@ -190,9 +186,6 @@ class DataTrainingArguments:
         if self.validation_file is not None:
             extension = self.validation_file.split(".")[-1]
             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-        if self.test_file is not None:
-            extension = self.test_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
 
 dataset_name_mapping = {
@@ -315,9 +308,6 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
-        if data_args.test_file is not None:
-            data_files["test"] = data_args.test_file
-            extension = data_args.test_file.split(".")[-1]
         dataset = load_dataset(
             extension,
             data_files=data_files,
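The surviving branch keys data_files by split and passes the file extension as the builder name to datasets.load_dataset. A minimal sketch of that pattern with hypothetical file names:

from datasets import load_dataset

# Hypothetical local files; the builder name ("json") is derived from the extension.
data_files = {"train": "train.json", "validation": "val.json"}
extension = data_files["validation"].split(".")[-1]

dataset = load_dataset(extension, data_files=data_files)
print(dataset["validation"].column_names)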
@@ -387,8 +377,6 @@ def main():
         column_names = dataset["train"].column_names
     elif training_args.do_eval:
         column_names = dataset["validation"].column_names
-    elif training_args.do_predict:
-        column_names = dataset["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
@@ -490,29 +478,6 @@ def main():
         # Transform images on the fly as doing it on the whole dataset takes too much time.
         eval_dataset.set_transform(transform_images)
 
-    if training_args.do_predict:
-        if "test" not in dataset:
-            raise ValueError("--do_predict requires a test dataset")
-        test_dataset = dataset["test"]
-        if data_args.max_eval_samples is not None:
-            max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
-            test_dataset = test_dataset.select(range(max_eval_samples))
-
-        test_dataset = test_dataset.filter(
-            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
-        )
-        test_dataset = test_dataset.map(
-            function=tokenize_captions,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=[col for col in column_names if col != image_column],
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on test dataset",
-        )
-
-        # Transform images on the fly as doing it on the whole dataset takes too much time.
-        test_dataset.set_transform(transform_images)
-
     # 8. Initialize our trainer
     trainer = Trainer(
         model=model,
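The removed block prepared a test split with filter, map, and set_transform, but per the commit title the resulting test_dataset was never consumed. The same lazy-transform idea survives on the eval split; a minimal sketch with a dummy in-memory dataset and a stand-in transform (both hypothetical, not the script's real code):

from datasets import Dataset

# Dummy split standing in for dataset["validation"].
eval_dataset = Dataset.from_dict({"image_path": ["a.jpg", "b.jpg"], "caption": ["a cat", "a dog"]})


def transform_images(examples):
    # Stand-in for the script's real image loading; applied lazily, batch by batch, at access time.
    examples["pixel_values"] = [[0.0, 0.0, 0.0] for _ in examples["image_path"]]
    return examples


# Transform images on the fly as doing it on the whole dataset takes too much time.
eval_dataset.set_transform(transform_images)
print(eval_dataset[0].keys())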