mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Add: tensorflow example for image classification task guide (#21038)
* Added TF example for image classification * Code style polishing * code style polishing * minor polishing * fixed a link in a tip, and a typo in the inference TF content * Apply Amy's suggestions from review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/tasks/image_classification.mdx Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * review feedback addressed * make style * added PushToHubCallback with save_strategy="no" * minor polishing * added PushToHubCallback with save_strategy=no * minor polishing * Update docs/source/en/tasks/image_classification.mdx * added data augmentation Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * make style Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
parent
3a9bd972e2
commit
868d37165f
@ -16,12 +16,14 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<Youtube id="tjAIM7BOYhw"/>
|
||||
|
||||
Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that comprise an image. There are many applications for image classification such as detecting damage after a natural disaster, monitoring crop health, or helping screen medical images for signs of disease.
|
||||
Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the
|
||||
pixel values that comprise an image. There are many applications for image classification, such as detecting damage
|
||||
after a natural disaster, monitoring crop health, or helping screen medical images for signs of disease.
|
||||
|
||||
This guide will show you how to:
|
||||
This guide illustrates how to:
|
||||
|
||||
1. Finetune [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) on the [Food-101](https://huggingface.co/datasets/food101) dataset to classify a food item in an image.
|
||||
2. Use your finetuned model for inference.
|
||||
1. Fine-tune [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) on the [Food-101](https://huggingface.co/datasets/food101) dataset to classify a food item in an image.
|
||||
2. Use your fine-tuned model for inference.
|
||||
|
||||
<Tip>
|
||||
|
||||
@ -35,7 +37,7 @@ Before you begin, make sure you have all the necessary libraries installed:
|
||||
pip install transformers datasets evaluate
|
||||
```
|
||||
|
||||
We encourage you to login to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to login:
|
||||
We encourage you to log in to your Hugging Face account to upload and share your model with the community. When prompted, enter your token to log in:
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import notebook_login
|
||||
@ -45,7 +47,8 @@ We encourage you to login to your Hugging Face account so you can upload and sha
|
||||
|
||||
## Load Food-101 dataset
|
||||
|
||||
Start by loading a smaller subset of the Food-101 dataset from the 🤗 Datasets library. This'll give you a chance to experiment and make sure everythings works before spending more time training on the full dataset.
|
||||
Start by loading a smaller subset of the Food-101 dataset from the 🤗 Datasets library. This will give you a chance to
|
||||
experiment and make sure everything works before spending more time training on the full dataset.
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
@ -67,12 +70,13 @@ Then take a look at an example:
|
||||
'label': 79}
|
||||
```
|
||||
|
||||
There are two fields:
|
||||
Each example in the dataset has two fields:
|
||||
|
||||
- `image`: a PIL image of the food item.
|
||||
- `label`: the label class of the food item.
|
||||
- `image`: a PIL image of the food item
|
||||
- `label`: the label class of the food item
|
||||
|
||||
To make it easier for the model to get the label name from the label id, create a dictionary that maps the label name to an integer and vice versa:
|
||||
To make it easier for the model to get the label name from the label id, create a dictionary that maps the label name
|
||||
to an integer and vice versa:
|
||||
|
||||
```py
|
||||
>>> labels = food["train"].features["label"].names
|
||||
@ -96,9 +100,12 @@ The next step is to load a ViT image processor to process the image into a tenso
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
>>> checkpoint = "google/vit-base-patch16-224-in21k"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||
```
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Apply some image transformations to the images to make the model more robust against overfitting. Here you'll use torchvision's [`transforms`](https://pytorch.org/vision/stable/transforms.html) module, but you can also use any image library you like.
|
||||
|
||||
Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:
|
||||
@ -130,17 +137,108 @@ To apply the preprocessing function over the entire dataset, use 🤗 Datasets [
|
||||
>>> food = food.with_transform(transforms)
|
||||
```
|
||||
|
||||
Now create a batch of examples using [`DataCollatorWithPadding`]. Unlike other data collators in 🤗 Transformers, the `DefaultDataCollator` does not apply additional preprocessing such as padding.
|
||||
Now create a batch of examples using [`DefaultDataCollator`]. Unlike other data collators in 🤗 Transformers, the `DefaultDataCollator` does not apply additional preprocessing such as padding.
|
||||
|
||||
```py
|
||||
>>> from transformers import DefaultDataCollator
|
||||
|
||||
>>> data_collator = DefaultDataCollator()
|
||||
```
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
To avoid overfitting and to make the model more robust, add some data augmentation to the training part of the dataset.
|
||||
Here we use Keras preprocessing layers to define the transformations for the training data (includes data augmentation),
|
||||
and transformations for the validation data (only center cropping, resizing and normalizing). You can use `tf.image`or
|
||||
any other library you prefer.
|
||||
|
||||
```py
|
||||
>>> from tensorflow import keras
|
||||
>>> from tensorflow.keras import layers
|
||||
|
||||
>>> size = (image_processor.size["height"], image_processor.size["width"])
|
||||
|
||||
>>> train_data_augmentation = keras.Sequential(
|
||||
... [
|
||||
... layers.RandomCrop(size[0], size[1]),
|
||||
... layers.Rescaling(scale=1.0 / 127.5, offset=-1),
|
||||
... layers.RandomFlip("horizontal"),
|
||||
... layers.RandomRotation(factor=0.02),
|
||||
... layers.RandomZoom(height_factor=0.2, width_factor=0.2),
|
||||
... ],
|
||||
... name="train_data_augmentation",
|
||||
... )
|
||||
|
||||
>>> val_data_augmentation = keras.Sequential(
|
||||
... [
|
||||
... layers.CenterCrop(size[0], size[1]),
|
||||
... layers.Rescaling(scale=1.0 / 127.5, offset=-1),
|
||||
... ],
|
||||
... name="val_data_augmentation",
|
||||
... )
|
||||
```
|
||||
|
||||
Next, create functions to apply appropriate transformations to a batch of images, instead of one image at a time.
|
||||
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import tensorflow as tf
|
||||
>>> from PIL import Image
|
||||
|
||||
|
||||
>>> def convert_to_tf_tensor(image: Image):
|
||||
... np_image = np.array(image)
|
||||
... tf_image = tf.convert_to_tensor(np_image)
|
||||
... # `expand_dims()` is used to add a batch dimension since
|
||||
... # the TF augmentation layers operates on batched inputs.
|
||||
... return tf.expand_dims(tf_image, 0)
|
||||
|
||||
|
||||
>>> def preprocess_train(example_batch):
|
||||
... """Apply train_transforms across a batch."""
|
||||
... images = [
|
||||
... train_data_augmentation(convert_to_tf_tensor(image.convert("RGB"))) for image in example_batch["image"]
|
||||
... ]
|
||||
... example_batch["pixel_values"] = [tf.transpose(tf.squeeze(image)) for image in images]
|
||||
... return example_batch
|
||||
|
||||
|
||||
... def preprocess_val(example_batch):
|
||||
... """Apply val_transforms across a batch."""
|
||||
... images = [
|
||||
... val_data_augmentation(convert_to_tf_tensor(image.convert("RGB"))) for image in example_batch["image"]
|
||||
... ]
|
||||
... example_batch["pixel_values"] = [tf.transpose(tf.squeeze(image)) for image in images]
|
||||
... return example_batch
|
||||
```
|
||||
|
||||
Use 🤗 Datasets [`~datasets.Dataset.set_transform`] to apply the transformations on the fly:
|
||||
|
||||
```py
|
||||
food["train"].set_transform(preprocess_train)
|
||||
food["test"].set_transform(preprocess_val)
|
||||
```
|
||||
|
||||
As a final preprocessing step, create a batch of examples using `DefaultDataCollator`. Unlike other data collators in 🤗 Transformers, the
|
||||
`DefaultDataCollator` does not apply additional preprocessing, such as padding.
|
||||
|
||||
```py
|
||||
>>> from transformers import DefaultDataCollator
|
||||
|
||||
>>> data_collator = DefaultDataCollator(return_tensors="tf")
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Evaluate
|
||||
|
||||
Including a metric during training is often helpful for evaluating your model's performance. You can quickly load a evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):
|
||||
Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an
|
||||
evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load
|
||||
the [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):
|
||||
|
||||
```py
|
||||
>>> import evaluate
|
||||
@ -155,11 +253,12 @@ Then create a function that passes your predictions and labels to [`~evaluate.Ev
|
||||
|
||||
|
||||
>>> def compute_metrics(eval_pred):
|
||||
... predictions = np.argmax(eval_pred.predictions, axis=1)
|
||||
... return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
|
||||
... predictions, labels = eval_pred
|
||||
... predictions = np.argmax(predictions, axis=1)
|
||||
... return accuracy.compute(predictions=predictions, references=labels)
|
||||
```
|
||||
|
||||
Your `compute_metrics` function is ready to go now, and you'll return to it when you setup your training.
|
||||
Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.
|
||||
|
||||
## Train
|
||||
|
||||
@ -170,13 +269,14 @@ Your `compute_metrics` function is ready to go now, and you'll return to it when
|
||||
If you aren't familiar with finetuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#train-with-pytorch-trainer)!
|
||||
|
||||
</Tip>
|
||||
|
||||
You're ready to start training your model now! Load ViT with [`AutoModelForImageClassification`]. Specify the number of labels along with the number of expected labels, and the label mappings:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
|
||||
|
||||
>>> model = AutoModelForImageClassification.from_pretrained(
|
||||
... "google/vit-base-patch16-224-in21k",
|
||||
... checkpoint,
|
||||
... num_labels=len(labels),
|
||||
... id2label=id2label,
|
||||
... label2id=label2id,
|
||||
@ -228,6 +328,115 @@ Once training is completed, share your model to the Hub with the [`~transformers
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
<Tip>
|
||||
|
||||
If you are unfamiliar with fine-tuning a model with Keras, check out the [basic tutorial](./training#train-a-tensorflow-model-with-keras) first!
|
||||
|
||||
</Tip>
|
||||
|
||||
To fine-tune a model in TensorFlow, follow these steps:
|
||||
1. Define the training hyperparameters, and set up an optimizer and a learning rate schedule.
|
||||
2. Instantiate a pre-treined model.
|
||||
3. Convert a 🤗 Dataset to a `tf.data.Dataset`.
|
||||
4. Compile your model.
|
||||
5. Add callbacks and use the `fit()` method to run the training.
|
||||
6. Upload your model to 🤗 Hub to share with the community.
|
||||
|
||||
Start by defining the hyperparameters, optimizer and learning rate schedule:
|
||||
|
||||
```py
|
||||
>>> from transformers import create_optimizer
|
||||
|
||||
>>> batch_size = 16
|
||||
>>> num_epochs = 5
|
||||
>>> num_train_steps = len(food["train"]) * num_epochs
|
||||
>>> learning_rate = 3e-5
|
||||
>>> weight_decay_rate = 0.01
|
||||
|
||||
>>> optimizer, lr_schedule = create_optimizer(
|
||||
... init_lr=learning_rate,
|
||||
... num_train_steps=num_train_steps,
|
||||
... weight_decay_rate=weight_decay_rate,
|
||||
... num_warmup_steps=0,
|
||||
... )
|
||||
```
|
||||
|
||||
Then, load ViT with [`TFAutoModelForImageClassification`] along with the label mappings:
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForImageClassification
|
||||
|
||||
>>> model = TFAutoModelForImageClassification.from_pretrained(
|
||||
... checkpoint,
|
||||
... id2label=id2label,
|
||||
... label2id=label2id,
|
||||
... )
|
||||
```
|
||||
|
||||
Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Dataset.to_tf_dataset`] and your `data_collator`:
|
||||
|
||||
```py
|
||||
>>> # converting our train dataset to tf.data.Dataset
|
||||
>>> tf_train_dataset = food["train"].to_tf_dataset(
|
||||
... columns=["pixel_values"], label_cols=["label"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
|
||||
... )
|
||||
|
||||
>>> # converting our test dataset to tf.data.Dataset
|
||||
>>> tf_eval_dataset = food["test"].to_tf_dataset(
|
||||
... columns=["pixel_values"], label_cols=["label"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
|
||||
... )
|
||||
```
|
||||
|
||||
Configure the model for training with `compile()`:
|
||||
|
||||
```py
|
||||
>>> from tensorflow.keras.losses import SparseCategoricalCrossentropy
|
||||
|
||||
>>> loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
>>> model.compile(optimizer=optimizer, loss=loss)
|
||||
```
|
||||
|
||||
To compute the accuracy from the predictions and push your model to the 🤗 Hub, use [Keras callbacks](./main_classes/keras_callbacks).
|
||||
Pass your `compute_metrics` function to [KerasMetricCallback](./main_classes/keras_callbacks#transformers.KerasMetricCallback),
|
||||
and use the [PushToHubCallback](./main_classes/keras_callbacks#transformers.PushToHubCallback) to upload the model:
|
||||
|
||||
```py
|
||||
>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||
|
||||
>>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
|
||||
>>> push_to_hub_callback = PushToHubCallback(
|
||||
... output_dir="food_classifier",
|
||||
... tokenizer=image_processor,
|
||||
... save_strategy="no",
|
||||
... )
|
||||
>>> callbacks = [metric_callback, push_to_hub_callback]
|
||||
```
|
||||
|
||||
Finally, you are ready to train your model! Call `fit()` with your training and validation datasets, the number of epochs,
|
||||
and your callbacks to fine-tune the model:
|
||||
|
||||
```py
|
||||
>>> model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)
|
||||
Epoch 1/5
|
||||
250/250 [==============================] - 313s 1s/step - loss: 2.5623 - val_loss: 1.4161 - accuracy: 0.9290
|
||||
Epoch 2/5
|
||||
250/250 [==============================] - 265s 1s/step - loss: 0.9181 - val_loss: 0.6808 - accuracy: 0.9690
|
||||
Epoch 3/5
|
||||
250/250 [==============================] - 252s 1s/step - loss: 0.3910 - val_loss: 0.4303 - accuracy: 0.9820
|
||||
Epoch 4/5
|
||||
250/250 [==============================] - 251s 1s/step - loss: 0.2028 - val_loss: 0.3191 - accuracy: 0.9900
|
||||
Epoch 5/5
|
||||
250/250 [==============================] - 238s 949ms/step - loss: 0.1232 - val_loss: 0.3259 - accuracy: 0.9890
|
||||
```
|
||||
|
||||
Congratulations! You have fine-tuned your model and shared it on the 🤗 Hub. You can now use it for inference!
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
For a more in-depth example of how to finetune a model for image classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
@ -236,7 +445,7 @@ For a more in-depth example of how to finetune a model for image classification,
|
||||
|
||||
## Inference
|
||||
|
||||
Great, now that you've finetuned a model, you can use it for inference!
|
||||
Great, now that you've fine-tuned a model, you can use it for inference!
|
||||
|
||||
Load an image you'd like to run inference on:
|
||||
|
||||
@ -256,12 +465,13 @@ The simplest way to try out your finetuned model for inference is to use it in a
|
||||
|
||||
>>> classifier = pipeline("image-classification", model="my_awesome_food_model")
|
||||
>>> classifier(image)
|
||||
[{'score': 0.35574808716773987, 'label': 'beignets'},
|
||||
{'score': 0.018057454377412796, 'label': 'chicken_wings'},
|
||||
{'score': 0.017733804881572723, 'label': 'prime_rib'},
|
||||
{'score': 0.016335085034370422, 'label': 'bruschetta'},
|
||||
{'score': 0.0160061065107584, 'label': 'ramen'}]
|
||||
[{'score': 0.31856709718704224, 'label': 'beignets'},
|
||||
{'score': 0.015232225880026817, 'label': 'bruschetta'},
|
||||
{'score': 0.01519392803311348, 'label': 'chicken_wings'},
|
||||
{'score': 0.013022331520915031, 'label': 'pork_chop'},
|
||||
{'score': 0.012728818692266941, 'label': 'prime_rib'}]
|
||||
```
|
||||
|
||||
You can also manually replicate the results of the `pipeline` if you'd like:
|
||||
|
||||
<frameworkcontent>
|
||||
@ -294,4 +504,35 @@ Get the predicted label with the highest probability, and use the model's `id2la
|
||||
'beignets'
|
||||
```
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
Load an image processor to preprocess the image and return the `input` as TensorFlow tensors:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/food_classifier")
|
||||
>>> inputs = image_processor(image, return_tensors="tf")
|
||||
```
|
||||
|
||||
Pass your inputs to the model and return the logits:
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForImageClassification
|
||||
|
||||
>>> model = TFAutoModelForImageClassification.from_pretrained("MariaK/food_classifier")
|
||||
>>> logits = model(**inputs).logits
|
||||
```
|
||||
|
||||
Get the predicted label with the highest probability, and use the model's `id2label` mapping to convert it to a label:
|
||||
|
||||
```py
|
||||
>>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
|
||||
>>> model.config.id2label[predicted_class_id]
|
||||
'beignets'
|
||||
```
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
Loading…
Reference in New Issue
Block a user