[Trainer] Allow passing image processor (#29896)

* Add image processor to trainer

* Replace tokenizer=image_processor everywhere
Authored by NielsRogge on 2024-04-05 10:10:44 +02:00, committed by GitHub
parent d704c0b698
commit 1ab7136488
21 changed files with 43 additions and 26 deletions

View File

@@ -322,7 +322,7 @@ At this point, only three steps remain:
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -418,7 +418,7 @@ and use the [PushToHubCallback](../main_classes/keras_callbacks#transformers.PushToHubCallback)
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]
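Taken together, the two hunks above change only the keyword: a vision `Trainer` and a Keras `PushToHubCallback` now take `image_processor` instead of overloading `tokenizer`. A minimal sketch of the new PyTorch usage, assuming `model`, `training_args`, `data_collator`, `food`, `image_processor`, and `compute_metrics` are defined as in the surrounding image-classification guide:

```python
from transformers import Trainer

# Sketch only: every name below except Trainer comes from the guide's setup.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    image_processor=image_processor,  # previously: tokenizer=image_processor
    compute_metrics=compute_metrics,
)
trainer.train()
```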

View File

@@ -384,7 +384,7 @@ Finally, bring everything together, and call [`~transformers.Trainer.train`]:
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )
 >>> trainer.train()

View File

@@ -642,7 +642,7 @@ and use the [`PushToHubCallback`] to upload the model:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )
->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```
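The same rename applies on the TensorFlow side. A sketch of the updated callback wiring, assuming `compute_metrics`, `tf_eval_dataset`, `batch_size`, and `image_processor` are defined as in the segmentation guide above:

```python
from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback

# Sketch only: compute_metrics, tf_eval_dataset, batch_size, and image_processor
# are assumed to come from the guide's setup.
metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
)
push_to_hub_callback = PushToHubCallback(
    output_dir="scene_segmentation",
    image_processor=image_processor,  # saved and uploaded alongside the model
)
callbacks = [metric_callback, push_to_hub_callback]
```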

View File

@@ -407,7 +407,7 @@ Then you just pass all of this along with the datasets to `Trainer`:
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )

View File

@@ -160,7 +160,7 @@ At this point, only three steps remain:
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )
 >>> trainer.train()

View File

@@ -328,7 +328,7 @@ food["test"].set_transform(preprocess_val)
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -426,7 +426,7 @@ Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Dataset.to_tf_dataset`]
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]

View File

@@ -376,7 +376,7 @@ "labels" on which the DETR model can be trained. The image processor
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )
 >>> trainer.train()

View File

@@ -434,7 +434,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )
->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```

View File

@@ -436,7 +436,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )
->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```

View File

@@ -414,7 +414,7 @@ def compute_metrics(eval_pred):
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )

View File

@@ -321,7 +321,7 @@ food["test"].set_transform(preprocess_val)
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -417,7 +417,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]

View File

@@ -366,7 +366,7 @@ DatasetDict({
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )
 >>> trainer.train()

View File

@@ -424,7 +424,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )
->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```

View File

@@ -411,7 +411,7 @@ def compute_metrics(eval_pred):
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )

View File

@@ -411,7 +411,7 @@ def main():
         train_dataset=dataset["train"] if training_args.do_train else None,
         eval_dataset=dataset["validation"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )

View File

@@ -369,7 +369,7 @@ def main():
         args=training_args,
         train_dataset=ds["train"] if training_args.do_train else None,
         eval_dataset=ds["validation"] if training_args.do_eval else None,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )

View File

@@ -458,7 +458,7 @@ def main():
         args=training_args,
         train_dataset=ds["train"] if training_args.do_train else None,
         eval_dataset=ds["validation"] if training_args.do_eval else None,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )

View File

@@ -510,7 +510,7 @@ def main():
         train_dataset=dataset["train"] if training_args.do_train else None,
         eval_dataset=dataset["validation"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=default_data_collator,
     )

View File

@@ -552,7 +552,7 @@ def main():
                 output_dir=training_args.output_dir,
                 hub_model_id=push_to_hub_model_id,
                 hub_token=training_args.push_to_hub_token,
-                tokenizer=image_processor,
+                image_processor=image_processor,
                 **model_card_kwargs,
             )
         )

View File

@@ -59,6 +59,7 @@ from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .image_processing_utils import BaseImageProcessor
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
@@ -303,6 +304,9 @@ class Trainer:
             The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
             maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
             interrupted training or reuse the fine-tuned model.
+        image_processor ([`BaseImageProcessor`], *optional*):
+            The image processor used to preprocess the data. If provided, it will be saved along the model to make it
+            easier to rerun an interrupted training or reuse the fine-tuned model.
         model_init (`Callable[[], PreTrainedModel]`, *optional*):
             A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
             from a new instance of the model as given by this function.
@@ -357,6 +361,7 @@ class Trainer:
         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        image_processor: Optional["BaseImageProcessor"] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
@@ -485,11 +490,12 @@
         ):
             self.place_model_on_device = False
-        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
+        self.image_processor = image_processor
 
         # Bnb quantized models don't support the `.to` operation.
         if (
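Note that the reordered expression above is behavior-preserving: `DataCollatorWithPadding` is still chosen only when a `tokenizer` is present, so a `Trainer` built with only an `image_processor` falls back to `default_data_collator`. That is why the task guides above all pass an explicit `data_collator`. A hypothetical minimal collator of that kind, assuming each example carries `pixel_values` and an integer `labels` field:

```python
import torch

# Hypothetical collate_fn, like the ones the task guides pass as data_collator.
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
```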
@@ -541,7 +547,7 @@
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
-            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+            callbacks, self.model, self.tokenizer, self.image_processor, self.optimizer, self.lr_scheduler
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
@@ -3276,6 +3282,8 @@ class Trainer:
             )
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
+        if self.image_processor is not None and self.args.should_save:
+            self.image_processor.save_pretrained(output_dir)
 
         # We moved the model from TPU -> CPU for saving the weights.
         # Now we should move it back so that subsequent compute still works.
@@ -3313,6 +3321,8 @@ class Trainer:
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(output_dir)
 
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
@@ -4009,6 +4019,9 @@ class Trainer:
         # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
+        # Same for the image processor
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(output_dir)
 
         # Same for the training arguments
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
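Because the save paths above now call `image_processor.save_pretrained`, the preprocessing configuration travels with every checkpoint and final save. A sketch of reloading a finished run, assuming an output directory named `food_classifier` and an image-classification head (both illustrative assumptions):

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Sketch only: the directory name and task head are assumptions for illustration.
image_processor = AutoImageProcessor.from_pretrained("food_classifier")
model = AutoModelForImageClassification.from_pretrained("food_classifier")
```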
@@ -4056,7 +4069,7 @@ class Trainer:
     def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
         """
-        Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
+        Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`.
 
         Parameters:
             commit_message (`str`, *optional*, defaults to `"End of training"`):

View File

@@ -189,6 +189,8 @@ class TrainerCallback:
         The model being trained.
     tokenizer ([`PreTrainedTokenizer`]):
         The tokenizer used for encoding the data.
+    image_processor ([`BaseImageProcessor`]):
+        The image processor used for encoding the images.
     optimizer (`torch.optim.Optimizer`):
         The optimizer used for the training steps.
     lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
@@ -307,12 +309,13 @@ class TrainerCallback:
 class CallbackHandler(TrainerCallback):
     """Internal class that just calls the list of callbacks in order."""
 
-    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
+    def __init__(self, callbacks, model, tokenizer, image_processor, optimizer, lr_scheduler):
         self.callbacks = []
         for cb in callbacks:
             self.add_callback(cb)
         self.model = model
         self.tokenizer = tokenizer
+        self.image_processor = image_processor
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
         self.train_dataloader = None
@@ -417,6 +420,7 @@ class CallbackHandler(TrainerCallback):
             control,
             model=self.model,
             tokenizer=self.tokenizer,
+            image_processor=self.image_processor,
             optimizer=self.optimizer,
             lr_scheduler=self.lr_scheduler,
             train_dataloader=self.train_dataloader,
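Since the handler now forwards `image_processor` with every event, any `TrainerCallback` can read it from its kwargs. A sketch with a hypothetical callback class:

```python
from transformers import TrainerCallback

# Hypothetical callback: the class name and the logging it does are
# illustrative only; the kwargs lookup follows the hunk above.
class ImageProcessorLogger(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        image_processor = kwargs.get("image_processor")
        if image_processor is not None:
            print(f"Using image processor: {type(image_processor).__name__}")
```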