mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Remove `task` arg in `load_dataset` in image-classification example (#28408)
* Remove `task` arg in `load_dataset` in image-classification example
* Manage case where "train" is not in dataset
* Add new args to manage image and label column names
* Similar to audio-classification example
* Fix README
* Update tests
This commit is contained in:
parent
edb170238f
commit
0cdcd7a2b3
@ -41,6 +41,7 @@ python run_image_classification.py \
|
||||
--dataset_name beans \
|
||||
--output_dir ./beans_outputs/ \
|
||||
--remove_unused_columns False \
|
||||
--label_column_name labels \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--push_to_hub \
|
||||
@ -197,7 +198,7 @@ accelerate test
|
||||
that will check everything is ready for training. Finally, you can launch training with
|
||||
|
||||
```bash
|
||||
accelerate launch run_image_classification_trainer.py
|
||||
accelerate launch run_image_classification_no_trainer.py --image_column_name img
|
||||
```
|
||||
|
||||
This command is the same and will work for:
|
||||
|
@ -111,6 +111,14 @@ class DataTrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
image_column_name: str = field(
|
||||
default="image",
|
||||
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
|
||||
)
|
||||
label_column_name: str = field(
|
||||
default="label",
|
||||
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
|
||||
@ -175,12 +183,6 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example["labels"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
@ -255,7 +257,6 @@ def main():
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
task="image-classification",
|
||||
token=model_args.token,
|
||||
)
|
||||
else:
|
||||
@ -268,9 +269,27 @@ def main():
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
task="image-classification",
|
||||
)
|
||||
|
||||
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
|
||||
if data_args.image_column_name not in dataset_column_names:
|
||||
raise ValueError(
|
||||
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||
"Make sure to set `--image_column_name` to the correct image column - one of "
|
||||
f"{', '.join(dataset_column_names)}."
|
||||
)
|
||||
if data_args.label_column_name not in dataset_column_names:
|
||||
raise ValueError(
|
||||
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||
"Make sure to set `--label_column_name` to the correct label column - one of "
|
||||
f"{', '.join(dataset_column_names)}."
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
|
||||
# If we don't have a validation split, split off a percentage of train as validation.
|
||||
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
|
||||
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
||||
@ -280,7 +299,7 @@ def main():
|
||||
|
||||
# Prepare label mappings.
|
||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||
labels = dataset["train"].features["labels"].names
|
||||
labels = dataset["train"].features[data_args.label_column_name].names
|
||||
label2id, id2label = {}, {}
|
||||
for i, label in enumerate(labels):
|
||||
label2id[label] = str(i)
|
||||
@ -354,13 +373,15 @@ def main():
|
||||
def train_transforms(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [
|
||||
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
|
||||
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
|
||||
]
|
||||
return example_batch
|
||||
|
||||
def val_transforms(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
|
||||
example_batch["pixel_values"] = [
|
||||
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
|
||||
]
|
||||
return example_batch
|
||||
|
||||
if training_args.do_train:
|
||||
|
@ -189,6 +189,18 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column_name",
|
||||
type=str,
|
||||
default="image",
|
||||
help="The name of the dataset column containing the image data. Defaults to 'image'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label_column_name",
|
||||
type=str,
|
||||
default="label",
|
||||
help="The name of the dataset column containing the labels. Defaults to 'label'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@ -272,7 +284,7 @@ def main():
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(args.dataset_name, task="image-classification")
|
||||
dataset = load_dataset(args.dataset_name)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_dir is not None:
|
||||
@ -282,11 +294,24 @@ def main():
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
task="image-classification",
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
|
||||
|
||||
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
|
||||
if args.image_column_name not in dataset_column_names:
|
||||
raise ValueError(
|
||||
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
|
||||
"Make sure to set `--image_column_name` to the correct image column - one of "
|
||||
f"{', '.join(dataset_column_names)}."
|
||||
)
|
||||
if args.label_column_name not in dataset_column_names:
|
||||
raise ValueError(
|
||||
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
|
||||
"Make sure to set `--label_column_name` to the correct label column - one of "
|
||||
f"{', '.join(dataset_column_names)}."
|
||||
)
|
||||
|
||||
# If we don't have a validation split, split off a percentage of train as validation.
|
||||
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
|
||||
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
|
||||
@ -296,7 +321,7 @@ def main():
|
||||
|
||||
# Prepare label mappings.
|
||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||
labels = dataset["train"].features["labels"].names
|
||||
labels = dataset["train"].features[args.label_column_name].names
|
||||
label2id = {label: str(i) for i, label in enumerate(labels)}
|
||||
id2label = {str(i): label for i, label in enumerate(labels)}
|
||||
|
||||
@ -355,12 +380,16 @@ def main():
|
||||
|
||||
def preprocess_train(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
||||
example_batch["pixel_values"] = [
|
||||
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
|
||||
]
|
||||
return example_batch
|
||||
|
||||
def preprocess_val(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
||||
example_batch["pixel_values"] = [
|
||||
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
|
||||
]
|
||||
return example_batch
|
||||
|
||||
with accelerator.main_process_first():
|
||||
@ -376,7 +405,7 @@ def main():
|
||||
# DataLoaders creation:
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example["labels"] for example in examples])
|
||||
labels = torch.tensor([example[args.label_column_name] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
|
@ -322,6 +322,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
--output_dir {tmp_dir}
|
||||
--with_tracking
|
||||
--checkpointing_steps 1
|
||||
--label_column_name labels
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + testargs)
|
||||
|
@ -398,6 +398,7 @@ class ExamplesTests(TestCasePlus):
|
||||
--max_steps 10
|
||||
--train_val_split 0.1
|
||||
--seed 42
|
||||
--label_column_name labels
|
||||
""".split()
|
||||
|
||||
if is_torch_fp16_available_on_device(torch_device):
|
||||
|
Loading…
Reference in New Issue
Block a user