Remove task arg in load_dataset in image-classification example (#28408)

* Remove `task` arg in `load_dataset` in image-classification example

* Manage case where "train" is not in dataset

* Add new args to manage image and label column names

* Similar to audio-classification example

* Fix README

* Update tests
This commit is contained in:
regisss 2024-01-16 08:04:08 +01:00 committed by GitHub
parent edb170238f
commit 0cdcd7a2b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 18 deletions

View File

@ -41,6 +41,7 @@ python run_image_classification.py \
--dataset_name beans \
--output_dir ./beans_outputs/ \
--remove_unused_columns False \
--label_column_name labels \
--do_train \
--do_eval \
--push_to_hub \
@ -197,7 +198,7 @@ accelerate test
that will check everything is ready for training. Finally, you can launch training with
```bash
accelerate launch run_image_classification_trainer.py
accelerate launch run_image_classification_no_trainer.py --image_column_name img
```
This command is the same and will work for:

View File

@ -111,6 +111,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)
def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
@ -175,12 +183,6 @@ class ModelArguments:
)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@ -255,7 +257,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
@ -268,9 +269,27 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
@ -280,7 +299,7 @@ def main():
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
@ -354,13 +373,15 @@ def main():
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch
def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch
if training_args.do_train:

View File

@ -189,6 +189,18 @@ def parse_args():
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
parser.add_argument(
"--image_column_name",
type=str,
default="image",
help="The name of the dataset column containing the image data. Defaults to 'image'.",
)
parser.add_argument(
"--label_column_name",
type=str,
default="label",
help="The name of the dataset column containing the labels. Defaults to 'label'.",
)
args = parser.parse_args()
# Sanity checks
@ -272,7 +284,7 @@ def main():
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(args.dataset_name, task="image-classification")
dataset = load_dataset(args.dataset_name)
else:
data_files = {}
if args.train_dir is not None:
@ -282,11 +294,24 @@ def main():
dataset = load_dataset(
"imagefolder",
data_files=data_files,
task="image-classification",
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)
# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
@ -296,7 +321,7 @@ def main():
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[args.label_column_name].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
@ -355,12 +380,16 @@ def main():
def preprocess_train(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch
def preprocess_val(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
example_batch["pixel_values"] = [
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
]
return example_batch
with accelerator.main_process_first():
@ -376,7 +405,7 @@ def main():
# DataLoaders creation:
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
labels = torch.tensor([example[args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
train_dataloader = DataLoader(

View File

@ -322,6 +322,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--output_dir {tmp_dir}
--with_tracking
--checkpointing_steps 1
--label_column_name labels
""".split()
run_command(self._launch_args + testargs)

View File

@ -398,6 +398,7 @@ class ExamplesTests(TestCasePlus):
--max_steps 10
--train_val_split 0.1
--seed 42
--label_column_name labels
""".split()
if is_torch_fp16_available_on_device(torch_device):