mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
Remove `task` arg in `load_dataset` in image-classification example (#28408)
* Remove `task` arg in `load_dataset` in image-classification example
* Manage case where "train" is not in dataset
* Add new args to manage image and label column names
* Similar to audio-classification example
* Fix README
* Update tests
This commit is contained in:
parent
edb170238f
commit
0cdcd7a2b3
@ -41,6 +41,7 @@ python run_image_classification.py \
|
|||||||
--dataset_name beans \
|
--dataset_name beans \
|
||||||
--output_dir ./beans_outputs/ \
|
--output_dir ./beans_outputs/ \
|
||||||
--remove_unused_columns False \
|
--remove_unused_columns False \
|
||||||
|
--label_column_name labels \
|
||||||
--do_train \
|
--do_train \
|
||||||
--do_eval \
|
--do_eval \
|
||||||
--push_to_hub \
|
--push_to_hub \
|
||||||
@ -197,7 +198,7 @@ accelerate test
|
|||||||
that will check everything is ready for training. Finally, you can launch training with
|
that will check everything is ready for training. Finally, you can launch training with
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
accelerate launch run_image_classification_trainer.py
|
accelerate launch run_image_classification_no_trainer.py --image_column_name img
|
||||||
```
|
```
|
||||||
|
|
||||||
This command is the same and will work for:
|
This command is the same and will work for:
|
||||||
|
@ -111,6 +111,14 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
image_column_name: str = field(
|
||||||
|
default="image",
|
||||||
|
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
|
||||||
|
)
|
||||||
|
label_column_name: str = field(
|
||||||
|
default="label",
|
||||||
|
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
|
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
|
||||||
@ -175,12 +183,6 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
|
||||||
labels = torch.tensor([example["labels"] for example in examples])
|
|
||||||
return {"pixel_values": pixel_values, "labels": labels}
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
@ -255,7 +257,6 @@ def main():
|
|||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
data_args.dataset_config_name,
|
data_args.dataset_config_name,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
task="image-classification",
|
|
||||||
token=model_args.token,
|
token=model_args.token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -268,9 +269,27 @@ def main():
|
|||||||
"imagefolder",
|
"imagefolder",
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
task="image-classification",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
|
||||||
|
if data_args.image_column_name not in dataset_column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||||
|
"Make sure to set `--image_column_name` to the correct audio column - one of "
|
||||||
|
f"{', '.join(dataset_column_names)}."
|
||||||
|
)
|
||||||
|
if data_args.label_column_name not in dataset_column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||||
|
"Make sure to set `--label_column_name` to the correct text column - one of "
|
||||||
|
f"{', '.join(dataset_column_names)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def collate_fn(examples):
|
||||||
|
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||||
|
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
|
||||||
|
return {"pixel_values": pixel_values, "labels": labels}
|
||||||
|
|
||||||
# If we don't have a validation split, split off a percentage of train as validation.
|
# If we don't have a validation split, split off a percentage of train as validation.
|
||||||
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
|
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
|
||||||
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
||||||
@ -280,7 +299,7 @@ def main():
|
|||||||
|
|
||||||
# Prepare label mappings.
|
# Prepare label mappings.
|
||||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||||
labels = dataset["train"].features["labels"].names
|
labels = dataset["train"].features[data_args.label_column_name].names
|
||||||
label2id, id2label = {}, {}
|
label2id, id2label = {}, {}
|
||||||
for i, label in enumerate(labels):
|
for i, label in enumerate(labels):
|
||||||
label2id[label] = str(i)
|
label2id[label] = str(i)
|
||||||
@ -354,13 +373,15 @@ def main():
|
|||||||
def train_transforms(example_batch):
|
def train_transforms(example_batch):
|
||||||
"""Apply _train_transforms across a batch."""
|
"""Apply _train_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [
|
example_batch["pixel_values"] = [
|
||||||
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
|
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
|
||||||
]
|
]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
def val_transforms(example_batch):
|
def val_transforms(example_batch):
|
||||||
"""Apply _val_transforms across a batch."""
|
"""Apply _val_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
|
example_batch["pixel_values"] = [
|
||||||
|
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
|
||||||
|
]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
@ -189,6 +189,18 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
|
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image_column_name",
|
||||||
|
type=str,
|
||||||
|
default="image",
|
||||||
|
help="The name of the dataset column containing the image data. Defaults to 'image'.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--label_column_name",
|
||||||
|
type=str,
|
||||||
|
default="label",
|
||||||
|
help="The name of the dataset column containing the labels. Defaults to 'label'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
@ -272,7 +284,7 @@ def main():
|
|||||||
# download the dataset.
|
# download the dataset.
|
||||||
if args.dataset_name is not None:
|
if args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
dataset = load_dataset(args.dataset_name, task="image-classification")
|
dataset = load_dataset(args.dataset_name)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
if args.train_dir is not None:
|
if args.train_dir is not None:
|
||||||
@ -282,11 +294,24 @@ def main():
|
|||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
"imagefolder",
|
"imagefolder",
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
task="image-classification",
|
|
||||||
)
|
)
|
||||||
# See more about loading custom images at
|
# See more about loading custom images at
|
||||||
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
|
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
|
||||||
|
|
||||||
|
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
|
||||||
|
if args.image_column_name not in dataset_column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
|
||||||
|
"Make sure to set `--image_column_name` to the correct audio column - one of "
|
||||||
|
f"{', '.join(dataset_column_names)}."
|
||||||
|
)
|
||||||
|
if args.label_column_name not in dataset_column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
|
||||||
|
"Make sure to set `--label_column_name` to the correct text column - one of "
|
||||||
|
f"{', '.join(dataset_column_names)}."
|
||||||
|
)
|
||||||
|
|
||||||
# If we don't have a validation split, split off a percentage of train as validation.
|
# If we don't have a validation split, split off a percentage of train as validation.
|
||||||
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
|
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
|
||||||
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
|
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
|
||||||
@ -296,7 +321,7 @@ def main():
|
|||||||
|
|
||||||
# Prepare label mappings.
|
# Prepare label mappings.
|
||||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||||
labels = dataset["train"].features["labels"].names
|
labels = dataset["train"].features[args.label_column_name].names
|
||||||
label2id = {label: str(i) for i, label in enumerate(labels)}
|
label2id = {label: str(i) for i, label in enumerate(labels)}
|
||||||
id2label = {str(i): label for i, label in enumerate(labels)}
|
id2label = {str(i): label for i, label in enumerate(labels)}
|
||||||
|
|
||||||
@ -355,12 +380,16 @@ def main():
|
|||||||
|
|
||||||
def preprocess_train(example_batch):
|
def preprocess_train(example_batch):
|
||||||
"""Apply _train_transforms across a batch."""
|
"""Apply _train_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
example_batch["pixel_values"] = [
|
||||||
|
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
|
||||||
|
]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
def preprocess_val(example_batch):
|
def preprocess_val(example_batch):
|
||||||
"""Apply _val_transforms across a batch."""
|
"""Apply _val_transforms across a batch."""
|
||||||
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
|
example_batch["pixel_values"] = [
|
||||||
|
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
|
||||||
|
]
|
||||||
return example_batch
|
return example_batch
|
||||||
|
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
@ -376,7 +405,7 @@ def main():
|
|||||||
# DataLoaders creation:
|
# DataLoaders creation:
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||||
labels = torch.tensor([example["labels"] for example in examples])
|
labels = torch.tensor([example[args.label_column_name] for example in examples])
|
||||||
return {"pixel_values": pixel_values, "labels": labels}
|
return {"pixel_values": pixel_values, "labels": labels}
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
|
@ -322,6 +322,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--output_dir {tmp_dir}
|
--output_dir {tmp_dir}
|
||||||
--with_tracking
|
--with_tracking
|
||||||
--checkpointing_steps 1
|
--checkpointing_steps 1
|
||||||
|
--label_column_name labels
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
|
@ -398,6 +398,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--max_steps 10
|
--max_steps 10
|
||||||
--train_val_split 0.1
|
--train_val_split 0.1
|
||||||
--seed 42
|
--seed 42
|
||||||
|
--label_column_name labels
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_torch_fp16_available_on_device(torch_device):
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
|
Loading…
Reference in New Issue
Block a user