mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Improve image classification example (#16585)
* Improve README * Make dataset_name argument optional * Improve local data * Fix bug * Improve README some more * Apply suggestions from code review * Improve README Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
3e4eec47f5
commit
048443db86
@ -14,13 +14,18 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Image classification examples
|
||||
# Image classification example
|
||||
|
||||
The following examples showcase how to fine-tune a `ViT` for image-classification using PyTorch.
|
||||
This directory contains a script, `run_image_classification.py`, that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT]((https://huggingface.co/docs/transformers/main/en/model_doc/convnext)), [ResNet]((https://huggingface.co/docs/transformers/main/en/model_doc/resnet)), [Swin Transformer]((https://huggingface.co/docs/transformers/main/en/model_doc/swin))...) using PyTorch. It can be used to fine-tune models on both well-known datasets (like [CIFAR-10](https://huggingface.co/datasets/cifar10), [Fashion MNIST](https://huggingface.co/datasets/fashion_mnist), ...) as well as on your own custom data.
|
||||
|
||||
## Using datasets from 🤗 `datasets`
|
||||
This page includes 2 sections:
|
||||
- [Using datasets from the hub](#using-datasets-from-🤗-hub)
|
||||
- [Using your own data](#using-your-own-data).
|
||||
|
||||
Here we show how to fine-tune a `ViT` on the [beans](https://huggingface.co/datasets/beans) dataset.
|
||||
|
||||
## Using datasets from 🤗 `Hub`
|
||||
|
||||
Here we show how to fine-tune a Vision Transformer (`ViT`) on the [beans](https://huggingface.co/datasets/beans) dataset, to classify the disease type of bean leaves.
|
||||
|
||||
👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
|
||||
|
||||
@ -46,36 +51,21 @@ python run_image_classification.py \
|
||||
--seed 1337
|
||||
```
|
||||
|
||||
Here we show how to fine-tune a `ViT` on the [cats_vs_dogs](https://huggingface.co/datasets/cats_vs_dogs) dataset.
|
||||
To fine-tune another model, simply provide the `--model_name_or_path` argument. To train on another dataset, simply set the `--dataset_name` argument.
|
||||
|
||||
👀 See the results here: [nateraw/vit-base-cats-vs-dogs](https://huggingface.co/nateraw/vit-base-cats-vs-dogs).
|
||||
|
||||
```bash
|
||||
python run_image_classification.py \
|
||||
--dataset_name cats_vs_dogs \
|
||||
--output_dir ./cats_vs_dogs_outputs/ \
|
||||
--remove_unused_columns False \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--push_to_hub \
|
||||
--push_to_hub_model_id vit-base-cats-vs-dogs \
|
||||
--fp16 True \
|
||||
--learning_rate 2e-4 \
|
||||
--num_train_epochs 5 \
|
||||
--per_device_train_batch_size 32 \
|
||||
--per_device_eval_batch_size 32 \
|
||||
--logging_strategy steps \
|
||||
--logging_steps 10 \
|
||||
--evaluation_strategy epoch \
|
||||
--save_strategy epoch \
|
||||
--load_best_model_at_end True \
|
||||
--save_total_limit 3 \
|
||||
--seed 1337
|
||||
```
|
||||
|
||||
## Using your own data
|
||||
|
||||
To use your own dataset, the training script expects the following directory structure:
|
||||
To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
|
||||
- you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
|
||||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
### Provide them as folders
|
||||
|
||||
If you provide your own folders with images, the script expects the following directory structure:
|
||||
|
||||
```bash
|
||||
root/dog/xxx.png
|
||||
@ -87,11 +77,10 @@ root/cat/nsdf3.png
|
||||
root/cat/[...]/asd932_.png
|
||||
```
|
||||
|
||||
Once you've prepared your dataset, you can can run the script like this:
|
||||
In other words, you need to organize your images in subfolders, based on their class. You can then run the script like this:
|
||||
|
||||
```bash
|
||||
python run_image_classification.py \
|
||||
--dataset_name nateraw/image-folder \
|
||||
--train_dir <path-to-train-root> \
|
||||
--output_dir ./outputs/ \
|
||||
--remove_unused_columns False \
|
||||
@ -99,12 +88,48 @@ python run_image_classification.py \
|
||||
--do_eval
|
||||
```
|
||||
|
||||
### 💡 The above will split the train dir into training and evaluation sets
|
||||
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
|
||||
|
||||
#### 💡 The above will split the train dir into training and evaluation sets
|
||||
- To control the split amount, use the `--train_val_split` flag.
|
||||
- To provide your own validation split in its own directory, you can pass the `--validation_dir <path-to-val-root>` flag.
|
||||
|
||||
### Upload your data to the hub, as a (possibly private) repo
|
||||
|
||||
## Sharing your model on 🤗 Hub
|
||||
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# example 1: local folder
|
||||
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
|
||||
|
||||
# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd)
|
||||
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
|
||||
|
||||
# example 3: remote files (suppoted formats are tar, gzip, zip, xz, rar, zstd)
|
||||
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
|
||||
|
||||
# example 4: providing several splits
|
||||
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
|
||||
```
|
||||
|
||||
`ImageFolder` will create a `label` column, and the label name is based on the directory name.
|
||||
|
||||
Next, push it to the hub!
|
||||
|
||||
```python
|
||||
dataset.push_to_hub("name_of_your_dataset")
|
||||
|
||||
# if you want to push to a private repo, simply pass private=True:
|
||||
dataset.push_to_hub("name_of_your_dataset", private=True)
|
||||
```
|
||||
|
||||
and that's it! You can now simply train your model simply by setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the hub](#using-datasets-from-🤗-hub)).
|
||||
|
||||
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
|
||||
|
||||
# Sharing your model on 🤗 Hub
|
||||
|
||||
0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account
|
||||
|
||||
@ -116,13 +141,21 @@ $ git config --global user.email "you@example.com"
|
||||
$ git config --global user.name "Your Name"
|
||||
```
|
||||
|
||||
2. Log in with your HuggingFace account credentials using `huggingface-cli`
|
||||
2. Log in with your HuggingFace account credentials using `huggingface-cli`:
|
||||
|
||||
```bash
|
||||
$ huggingface-cli login
|
||||
# ...follow the prompts
|
||||
```
|
||||
|
||||
or, in case you're running in a notebook:
|
||||
|
||||
```python
|
||||
from huggingface_hub import notebook_login
|
||||
|
||||
notebook_login()
|
||||
```
|
||||
|
||||
3. When running the script, pass the following arguments:
|
||||
|
||||
```bash
|
||||
|
@ -72,13 +72,15 @@ def pil_loader(path: str):
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
Using `HfArgumentParser` we can turn this class
|
||||
into argparse arguments to be able to specify them on
|
||||
the command line.
|
||||
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
|
||||
them on the command line.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default="nateraw/image-folder", metadata={"help": "Name of a dataset from the datasets package"}
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
|
||||
},
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
@ -104,12 +106,10 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
data_files = dict()
|
||||
if self.train_dir is not None:
|
||||
data_files["train"] = self.train_dir
|
||||
if self.validation_dir is not None:
|
||||
data_files["val"] = self.validation_dir
|
||||
self.data_files = data_files if data_files else None
|
||||
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
|
||||
raise ValueError(
|
||||
"You must specify either a dataset name from the hub or a train and/or validation directory."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -201,25 +201,37 @@ def main():
|
||||
)
|
||||
|
||||
# Initialize our dataset and prepare it for the 'image-classification' task.
|
||||
ds = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
data_files=data_args.data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
task="image-classification",
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
if data_args.dataset_name is not None:
|
||||
dataset = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
task="image-classification",
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_dir is not None:
|
||||
data_files["train"] = os.path.join(data_args.train_dir, "**")
|
||||
if data_args.validation_dir is not None:
|
||||
data_files["validation"] = os.path.join(data_args.validation_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
task="image-classification",
|
||||
)
|
||||
|
||||
# If we don't have a validation split, split off a percentage of train as validation.
|
||||
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
|
||||
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
|
||||
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
||||
split = ds["train"].train_test_split(data_args.train_val_split)
|
||||
ds["train"] = split["train"]
|
||||
ds["validation"] = split["test"]
|
||||
split = dataset["train"].train_test_split(data_args.train_val_split)
|
||||
dataset["train"] = split["train"]
|
||||
dataset["validation"] = split["test"]
|
||||
|
||||
# Prepare label mappings.
|
||||
# We'll include these in the model's config to get human readable labels in the Inference API.
|
||||
labels = ds["train"].features["labels"].names
|
||||
labels = dataset["train"].features["labels"].names
|
||||
label2id, id2label = dict(), dict()
|
||||
for i, label in enumerate(labels):
|
||||
label2id[label] = str(i)
|
||||
@ -291,29 +303,31 @@ def main():
|
||||
return example_batch
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in ds:
|
||||
if "train" not in dataset:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
if data_args.max_train_samples is not None:
|
||||
ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||
dataset["train"] = (
|
||||
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||
)
|
||||
# Set the training transforms
|
||||
ds["train"].set_transform(train_transforms)
|
||||
dataset["train"].set_transform(train_transforms)
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in ds:
|
||||
if "validation" not in dataset:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
if data_args.max_eval_samples is not None:
|
||||
ds["validation"] = (
|
||||
ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||
dataset["validation"] = (
|
||||
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||
)
|
||||
# Set the validation transforms
|
||||
ds["validation"].set_transform(val_transforms)
|
||||
dataset["validation"].set_transform(val_transforms)
|
||||
|
||||
# Initalize our trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=ds["train"] if training_args.do_train else None,
|
||||
eval_dataset=ds["validation"] if training_args.do_eval else None,
|
||||
train_dataset=dataset["train"] if training_args.do_train else None,
|
||||
eval_dataset=dataset["validation"] if training_args.do_eval else None,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=feature_extractor,
|
||||
data_collator=collate_fn,
|
||||
@ -343,7 +357,7 @@ def main():
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"tasks": "image-classification",
|
||||
"dataset": data_args.dataset_name,
|
||||
"tags": ["image-classification"],
|
||||
"tags": ["image-classification", "vision"],
|
||||
}
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user