mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Audio/vision task guides (#15808)
* 📝 first draft of audio/vision guides * ✨ make fixup * 🖍 fix typo * 🖍 close parentheses * 🖍 apply feedback * 🖍 apply feedback, make fixup * 🖍 more fixup for perceiver * 🖍 apply feedback * ✨ make fixup * 🖍 fix data collator
This commit is contained in:
parent
cb5e50c8c2
commit
ae2dd42be5
@ -56,6 +56,12 @@
|
||||
title: Summarization
|
||||
- local: tasks/multiple_choice
|
||||
title: Multiple choice
|
||||
- local: tasks/audio_classification
|
||||
title: Audio classification
|
||||
- local: tasks/asr
|
||||
title: Automatic speech recognition
|
||||
- local: tasks/image_classification
|
||||
title: Image classification
|
||||
title: Fine-tune for downstream tasks
|
||||
- local: run_scripts
|
||||
title: Train with a script
|
||||
|
214
docs/source/tasks/asr.mdx
Normal file
214
docs/source/tasks/asr.mdx
Normal file
@ -0,0 +1,214 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Automatic speech recognition
|
||||
|
||||
<Youtube id="TksaY_FDgnk"/>
|
||||
|
||||
Automatic speech recognition (ASR) converts a speech signal to text. It is an example of a sequence-to-sequence task, going from a sequence of audio inputs to textual outputs. Voice assistants like Siri and Alexa utilize ASR models to assist users.
|
||||
|
||||
This guide will show you how to fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [TIMIT](https://huggingface.co/datasets/timit_asr) dataset to transcribe audio to text.
|
||||
|
||||
<Tip>
|
||||
|
||||
See the automatic speech recognition [task page](https://huggingface.co/tasks/automatic-speech-recognition) for more information about its associated models, datasets, and metrics.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Load TIMIT dataset
|
||||
|
||||
Load the TIMIT dataset from the 🤗 Datasets library:
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> timit = load_dataset("timit_asr")
|
||||
```
|
||||
|
||||
Then take a look at an example:
|
||||
|
||||
```py
|
||||
>>> timit
|
||||
DatasetDict({
|
||||
train: Dataset({
|
||||
features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
|
||||
num_rows: 4620
|
||||
})
|
||||
test: Dataset({
|
||||
features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
|
||||
num_rows: 1680
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
While the dataset contains a lot of helpful information, like `dialect_region` and `sentence_type`, you will focus on the `audio` and `text` fields in this guide. Remove the other columns:
|
||||
|
||||
```py
|
||||
>>> timit = timit.remove_columns(
|
||||
... ["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"]
|
||||
... )
|
||||
```
|
||||
|
||||
Take a look at the example again:
|
||||
|
||||
```py
|
||||
>>> timit["train"][0]
|
||||
{'audio': {'array': array([-2.1362305e-04, 6.1035156e-05, 3.0517578e-05, ...,
|
||||
-3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
|
||||
'path': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
|
||||
'sampling_rate': 16000},
|
||||
'file': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
|
||||
'text': 'Would such an act of refusal be useful?'}
|
||||
```
|
||||
|
||||
The `audio` column contains a 1-dimensional `array` of the speech signal that must be called to load and resample the audio file.
|
||||
|
||||
## Preprocess
|
||||
|
||||
Load the Wav2Vec2 processor to process the audio signal and transcribed text:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoProcessor
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
|
||||
```
|
||||
|
||||
The preprocessing function needs to:
|
||||
|
||||
1. Call the `audio` column to load and resample the audio file.
|
||||
2. Extract the `input_values` from the audio file.
|
||||
3. Typically, when you call the processor, you call the feature extractor. Since you also want to tokenize text, instruct the processor to call the tokenizer instead with a context manager.
|
||||
|
||||
```py
|
||||
>>> def prepare_dataset(batch):
|
||||
... audio = batch["audio"]
|
||||
|
||||
... batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
|
||||
... batch["input_length"] = len(batch["input_values"])
|
||||
|
||||
... with processor.as_target_processor():
|
||||
... batch["labels"] = processor(batch["text"]).input_ids
|
||||
... return batch
|
||||
```
|
||||
|
||||
Use 🤗 Datasets [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) function to apply the preprocessing function over the entire dataset. You can speed up the map function by increasing the number of processes with `num_proc`. Remove the columns you don't need:
|
||||
|
||||
```py
|
||||
>>> timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)
|
||||
```
|
||||
|
||||
🤗 Transformers doesn't have a data collator for automatic speech recognition, so you will need to create one. You can adapt the [`DataCollatorWithPadding`] to create a batch of examples for automatic speech recognition. It will also dynamically pad your text and labels to the length of the longest element in its batch, so they are a uniform length. While it is possible to pad your text in the `tokenizer` function by setting `padding=True`, dynamic padding is more efficient.
|
||||
|
||||
Unlike other data collators, this specific data collator needs to apply a different padding method to `input_values` and `labels`. You can apply a different padding method with a context manager:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> from dataclasses import dataclass, field
|
||||
>>> from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class DataCollatorCTCWithPadding:
|
||||
|
||||
... processor: AutoProcessor
|
||||
... padding: Union[bool, str] = True
|
||||
|
||||
... def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
... # split inputs and labels since they have to be of different lengths and need
|
||||
... # different padding methods
|
||||
... input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||
... label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||
|
||||
... batch = self.processor.pad(
|
||||
... input_features,
|
||||
... padding=self.padding,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
... with self.processor.as_target_processor():
|
||||
... labels_batch = self.processor.pad(
|
||||
... label_features,
|
||||
... padding=self.padding,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
... # replace padding with -100 to ignore loss correctly
|
||||
... labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||
|
||||
... batch["labels"] = labels
|
||||
|
||||
... return batch
|
||||
```
|
||||
|
||||
Create a batch of examples and dynamically pad them with `DataCollatorForCTCWithPadding`:
|
||||
|
||||
```py
|
||||
>>> data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
|
||||
```
|
||||
|
||||
## Fine-tune with Trainer
|
||||
|
||||
Load Wav2Vec2 with [`AutoModelForCTC`]. For `ctc_loss_reduction`, it is often better to use the average instead of the default summation:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer
|
||||
|
||||
>>> model = AutoModelForCTC.from_pretrained(
|
||||
... "facebook/wav2vec-base",
|
||||
... ctc_loss_reduction="mean",
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... )
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](training#finetune-with-trainer)!
|
||||
|
||||
</Tip>
|
||||
|
||||
At this point, only three steps remain:
|
||||
|
||||
1. Define your training hyperparameters in [`TrainingArguments`].
|
||||
2. Pass the training arguments to [`Trainer`] along with the model, datasets, tokenizer, and data collator.
|
||||
3. Call [`~Trainer.train`] to fine-tune your model.
|
||||
|
||||
```py
|
||||
>>> training_args = TrainingArguments(
|
||||
... output_dir="./results",
|
||||
... group_by_length=True,
|
||||
... per_device_train_batch_size=16,
|
||||
... evaluation_strategy="steps",
|
||||
... num_train_epochs=3,
|
||||
... fp16=True,
|
||||
... gradient_checkpointing=True,
|
||||
... learning_rate=1e-4,
|
||||
... weight_decay=0.005,
|
||||
... save_total_limit=2,
|
||||
... )
|
||||
|
||||
>>> trainer = Trainer(
|
||||
... model=model,
|
||||
... args=training_args,
|
||||
... train_dataset=timit["train"],
|
||||
... eval_dataset=timit["test"],
|
||||
... tokenizer=processor.feature_extractor,
|
||||
... data_collator=data_collator,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
For a more in-depth example of how to fine-tune a model for automatic speech recognition, take a look at this blog [post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.
|
||||
|
||||
</Tip>
|
143
docs/source/tasks/audio_classification.mdx
Normal file
143
docs/source/tasks/audio_classification.mdx
Normal file
@ -0,0 +1,143 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Audio classification
|
||||
|
||||
<Youtube id="KWwzcmG98Ds"/>
|
||||
|
||||
Audio classification assigns a label or class to audio data. It is similar to text classification, except an audio input is continuous and must be discretized, whereas text can be split into tokens. Some practical applications of audio classification include identifying intent, speakers, and even animal species by their sounds.
|
||||
|
||||
This guide will show you how to fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the Keyword Spotting subset of the [SUPERB](https://huggingface.co/datasets/superb) benchmark to classify utterances.
|
||||
|
||||
<Tip>
|
||||
|
||||
See the audio classification [task page](https://huggingface.co/tasks/audio-classification) for more information about its associated models, datasets, and metrics.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Load SUPERB dataset
|
||||
|
||||
Load the SUPERB dataset from the 🤗 Datasets library:
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> ks = load_dataset("superb", "ks")
|
||||
```
|
||||
|
||||
Then take a look at an example:
|
||||
|
||||
```py
|
||||
>>> ks["train"][0]
|
||||
{'audio': {'array': array([ 0. , 0. , 0. , ..., -0.00592041, -0.00405884, -0.00253296], dtype=float32), 'path': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'sampling_rate': 16000}, 'file': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'label': 10}
|
||||
```
|
||||
|
||||
The `audio` column contains a 1-dimensional `array` of the speech signal that must be called to load and resample the audio file. The `label` column is an integer that represents the utterance class. Create a dictionary that maps a label name to an integer and vice versa. The mapping will help the model recover the label name from the label number:
|
||||
|
||||
```py
|
||||
>>> labels = ks["train"].features["label"].names
|
||||
>>> label2id, id2label = dict(), dict()
|
||||
>>> for i, label in enumerate(labels):
|
||||
... label2id[label] = str(i)
|
||||
... id2label[str(i)] = label
|
||||
```
|
||||
|
||||
Now you can convert the label number to a label name for more information:
|
||||
|
||||
```py
|
||||
>>> id2label[str(10)]
|
||||
'_silence_'
|
||||
```
|
||||
|
||||
Each keyword - or label - corresponds to a number; `10` indicates `silence` in the example above.
|
||||
|
||||
## Preprocess
|
||||
|
||||
Load the Wav2Vec2 feature extractor to process the audio signal:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoFeatureExtractor
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||
```
|
||||
|
||||
The preprocessing function needs to:
|
||||
|
||||
1. Call the `audio` column to load and if necessary resample the audio file.
|
||||
2. Check the sampling rate of the audio file matches the sampling rate of the audio data a model was pretrained with. You can find this information on the Wav2Vec2 [model card]((https://huggingface.co/facebook/wav2vec2-base)).
|
||||
3. Set a maximum input length so longer inputs are batched without being truncated.
|
||||
|
||||
```py
|
||||
>>> def preprocess_function(examples):
|
||||
... audio_arrays = [x["array"] for x in examples["audio"]]
|
||||
... inputs = feature_extractor(
|
||||
... audio_arrays, sampling_rate=feature_extractor.sampling_rate, max_length=16000, truncation=True
|
||||
... )
|
||||
... return inputs
|
||||
```
|
||||
|
||||
Use 🤗 Datasets [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) function to apply the preprocessing function over the entire dataset. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once. Remove the columns you don't need:
|
||||
|
||||
```py
|
||||
>>> encoded_ks = ks.map(preprocess_function, remove_columns=["audio", "file"], batched=True)
|
||||
```
|
||||
|
||||
## Fine-tune with Trainer
|
||||
|
||||
Load Wav2Vec2 with [`AutoModelForAudioClassification`]. Specify the number of labels, and pass the model the mapping between label number and label class:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
|
||||
|
||||
>>> num_labels = len(id2label)
|
||||
>>> model = AutoModelForAudioClassification.from_pretrained(
|
||||
... "facebook/wav2vec2-base", num_labels=num_labels, label2id=label2id, id2label=id2label
|
||||
... )
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](training#finetune-with-trainer)!
|
||||
|
||||
</Tip>
|
||||
|
||||
At this point, only three steps remain:
|
||||
|
||||
1. Define your training hyperparameters in [`TrainingArguments`].
|
||||
2. Pass the training arguments to [`Trainer`] along with the model, datasets, and feature extractor.
|
||||
3. Call [`~Trainer.train`] to fine-tune your model.
|
||||
|
||||
```py
|
||||
>>> training_args = TrainingArguments(
|
||||
... output_dir="./results",
|
||||
... evaluation_strategy="epoch",
|
||||
... save_strategy="epoch",
|
||||
... learning_rate=3e-5,
|
||||
... num_train_epochs=5,
|
||||
... )
|
||||
|
||||
>>> trainer = Trainer(
|
||||
... model=model,
|
||||
... args=training_args,
|
||||
... train_dataset=encoded_ks["train"],
|
||||
... eval_dataset=encoded_ks["validation"],
|
||||
... tokenizer=feature_extractor,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
For a more in-depth example of how to fine-tune a model for audio classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/audio_classification.ipynb).
|
||||
|
||||
</Tip>
|
170
docs/source/tasks/image_classification.mdx
Normal file
170
docs/source/tasks/image_classification.mdx
Normal file
@ -0,0 +1,170 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Image classification
|
||||
|
||||
<Youtube id="tjAIM7BOYhw"/>
|
||||
|
||||
Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that represent an image. There are many uses for image classification, like detecting damage after a disaster, monitoring crop health, or helping screen medical images for signs of disease.
|
||||
|
||||
This guide will show you how to fine-tune [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) on the [Food-101](https://huggingface.co/datasets/food101) dataset to classify a food item in an image.
|
||||
|
||||
<Tip>
|
||||
|
||||
See the image classification [task page](https://huggingface.co/tasks/audio-classification) for more information about its associated models, datasets, and metrics.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Load Food-101 dataset
|
||||
|
||||
Load only the first 5000 images of the Food-101 dataset from the 🤗 Datasets library since it is pretty large:
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> food = load_dataset("food101", split="train[:5000]")
|
||||
```
|
||||
|
||||
Split this dataset into a train and test set:
|
||||
|
||||
```py
|
||||
>>> food = food.train_test_split(test_size=0.2)
|
||||
```
|
||||
|
||||
Then take a look at an example:
|
||||
|
||||
```py
|
||||
>>> food["train"][0]
|
||||
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
|
||||
'label': 79}
|
||||
```
|
||||
|
||||
The `image` field contains a PIL image, and each `label` is an integer that represents a class. Create a dictionary that maps a label name to an integer and vice versa. The mapping will help the model recover the label name from the label number:
|
||||
|
||||
```py
|
||||
>>> labels = food["train"].features["label"].names
|
||||
>>> label2id, id2label = dict(), dict()
|
||||
>>> for i, label in enumerate(labels):
|
||||
... label2id[label] = str(i)
|
||||
... id2label[str(i)] = label
|
||||
```
|
||||
|
||||
Now you can convert the label number to a label name for more information:
|
||||
|
||||
```py
|
||||
>>> id2label[str(79)]
|
||||
'prime_rib'
|
||||
```
|
||||
|
||||
Each food class - or label - corresponds to a number; `79` indicates a prime rib in the example above.
|
||||
|
||||
## Preprocess
|
||||
|
||||
Load the ViT feature extractor to process the image into a tensor:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoFeatureExtractor
|
||||
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
```
|
||||
|
||||
Apply several image transformations to the dataset to make the model more robust against overfitting. Here you'll use torchvision's [`transforms`](https://pytorch.org/vision/stable/transforms.html) module. Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:
|
||||
|
||||
```py
|
||||
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
|
||||
|
||||
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
|
||||
```
|
||||
|
||||
Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:
|
||||
|
||||
```py
|
||||
>>> def transforms(examples):
|
||||
... examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
|
||||
... del examples["image"]
|
||||
... return examples
|
||||
```
|
||||
|
||||
Use 🤗 Dataset's [`with_transform`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?#datasets.Dataset.with_transform) method to apply the transforms over the entire dataset. The transforms are applied on-the-fly when you load an element of the dataset:
|
||||
|
||||
```py
|
||||
>>> food = food.with_transform(transforms)
|
||||
```
|
||||
|
||||
Use [`DefaultDataCollator`] to create a batch of examples. Unlike other data collators in 🤗 Transformers, the DefaultDataCollator does not apply additional preprocessing such as padding.
|
||||
|
||||
```py
|
||||
>>> from transformers import DefaultDataCollator
|
||||
|
||||
>>> data_collator = DefaultDataCollator()
|
||||
```
|
||||
|
||||
## Fine-tune with Trainer
|
||||
|
||||
Load ViT with [`AutoModelForImageClassification`]. Specify the number of labels, and pass the model the mapping between label number and label class:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
|
||||
|
||||
>>> model = AutoModelForImageClassification.from_pretrained(
|
||||
... "google/vit-base-patch16-224-in21k",
|
||||
... num_labels=len(labels),
|
||||
... id2label=id2label,
|
||||
... label2id=label2id,
|
||||
... )
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](training#finetune-with-trainer)!
|
||||
|
||||
</Tip>
|
||||
|
||||
At this point, only three steps remain:
|
||||
|
||||
1. Define your training hyperparameters in [`TrainingArguments`]. It is important you don't remove unused columns because this will drop the `image` column. Without the `image` column, you can't create `pixel_values`. Set `remove_unused_columns=False` to prevent this behavior!
|
||||
2. Pass the training arguments to [`Trainer`] along with the model, datasets, tokenizer, and data collator.
|
||||
3. Call [`~Trainer.train`] to fine-tune your model.
|
||||
|
||||
```py
|
||||
>>> training_args = TrainingArguments(
|
||||
... output_dir="./results",
|
||||
... per_device_train_batch_size=16,
|
||||
... evaluation_strategy="steps",
|
||||
... num_train_epochs=4,
|
||||
... fp16=True,
|
||||
... save_steps=100,
|
||||
... eval_steps=100,
|
||||
... logging_steps=10,
|
||||
... learning_rate=2e-4,
|
||||
... save_total_limit=2,
|
||||
... remove_unused_columns=False,
|
||||
... )
|
||||
|
||||
>>> trainer = Trainer(
|
||||
... model=model,
|
||||
... args=training_args,
|
||||
... data_collator=data_collator,
|
||||
... train_dataset=food["train"],
|
||||
... eval_dataset=food["test"],
|
||||
... tokenizer=feature_extractor,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
For a more in-depth example of how to fine-tune a model for image classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/image_classification.ipynb).
|
||||
|
||||
</Tip>
|
Loading…
Reference in New Issue
Block a user