From 1ab71364886010c31b20dd8c8bb0c60f8a0681ad Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Fri, 5 Apr 2024 10:10:44 +0200 Subject: [PATCH] [Trainer] Allow passing image processor (#29896) * Add image processor to trainer * Replace tokenizer=image_processor everywhere --- docs/source/en/tasks/image_classification.md | 4 ++-- docs/source/en/tasks/object_detection.md | 2 +- docs/source/en/tasks/semantic_segmentation.md | 2 +- docs/source/en/tasks/video_classification.md | 2 +- docs/source/es/tasks/image_classification.md | 2 +- docs/source/ja/tasks/image_classification.md | 4 ++-- docs/source/ja/tasks/object_detection.md | 2 +- docs/source/ja/tasks/semantic_segmentation.md | 2 +- .../ja/tasks/sequence_classification.md | 2 +- docs/source/ja/tasks/video_classification.md | 2 +- docs/source/ko/tasks/image_classification.md | 4 ++-- docs/source/ko/tasks/object_detection.md | 2 +- docs/source/ko/tasks/semantic_segmentation.md | 2 +- docs/source/ko/tasks/video_classification.md | 2 +- .../run_image_classification.py | 2 +- examples/pytorch/image-pretraining/run_mae.py | 2 +- examples/pytorch/image-pretraining/run_mim.py | 2 +- .../run_semantic_segmentation.py | 2 +- .../run_image_classification.py | 2 +- src/transformers/trainer.py | 19 ++++++++++++++++--- src/transformers/trainer_callback.py | 6 +++++- 21 files changed, 43 insertions(+), 26 deletions(-) diff --git a/docs/source/en/tasks/image_classification.md b/docs/source/en/tasks/image_classification.md index 30c517f3be6..f54b4ed025d 100644 --- a/docs/source/en/tasks/image_classification.md +++ b/docs/source/en/tasks/image_classification.md @@ -322,7 +322,7 @@ At this point, only three steps remain: ... data_collator=data_collator, ... train_dataset=food["train"], ... eval_dataset=food["test"], -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... ) @@ -418,7 +418,7 @@ and use the [PushToHubCallback](../main_classes/keras_callbacks#transformers.Pus >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset) >>> push_to_hub_callback = PushToHubCallback( ... output_dir="food_classifier", -... tokenizer=image_processor, +... image_processor=image_processor, ... save_strategy="no", ... ) >>> callbacks = [metric_callback, push_to_hub_callback] diff --git a/docs/source/en/tasks/object_detection.md b/docs/source/en/tasks/object_detection.md index 2513591f545..56d46e4aa52 100644 --- a/docs/source/en/tasks/object_detection.md +++ b/docs/source/en/tasks/object_detection.md @@ -384,7 +384,7 @@ Finally, bring everything together, and call [`~transformers.Trainer.train`]: ... args=training_args, ... data_collator=collate_fn, ... train_dataset=cppe5["train"], -... tokenizer=image_processor, +... image_processor=image_processor, ... ) >>> trainer.train() diff --git a/docs/source/en/tasks/semantic_segmentation.md b/docs/source/en/tasks/semantic_segmentation.md index e99499bbbbd..ba40ccba1ec 100644 --- a/docs/source/en/tasks/semantic_segmentation.md +++ b/docs/source/en/tasks/semantic_segmentation.md @@ -642,7 +642,7 @@ and use the [`PushToHubCallback`] to upload the model: ... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"] ... ) ->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor) +>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor) >>> callbacks = [metric_callback, push_to_hub_callback] ``` diff --git a/docs/source/en/tasks/video_classification.md b/docs/source/en/tasks/video_classification.md index 38bdceba41b..a0f0a695f70 100644 --- a/docs/source/en/tasks/video_classification.md +++ b/docs/source/en/tasks/video_classification.md @@ -407,7 +407,7 @@ Then you just pass all of this along with the datasets to `Trainer`: ... args, ... train_dataset=train_dataset, ... eval_dataset=val_dataset, -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... data_collator=collate_fn, ... ) diff --git a/docs/source/es/tasks/image_classification.md b/docs/source/es/tasks/image_classification.md index f09730caf69..4a572d81698 100644 --- a/docs/source/es/tasks/image_classification.md +++ b/docs/source/es/tasks/image_classification.md @@ -160,7 +160,7 @@ Al llegar a este punto, solo quedan tres pasos: ... data_collator=data_collator, ... train_dataset=food["train"], ... eval_dataset=food["test"], -... tokenizer=image_processor, +... image_processor=image_processor, ... ) >>> trainer.train() diff --git a/docs/source/ja/tasks/image_classification.md b/docs/source/ja/tasks/image_classification.md index f8d8d0d5523..fc57cf4dfb9 100644 --- a/docs/source/ja/tasks/image_classification.md +++ b/docs/source/ja/tasks/image_classification.md @@ -328,7 +328,7 @@ food["test"].set_transform(preprocess_val) ... data_collator=data_collator, ... train_dataset=food["train"], ... eval_dataset=food["test"], -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... ) @@ -426,7 +426,7 @@ Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Data >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset) >>> push_to_hub_callback = PushToHubCallback( ... output_dir="food_classifier", -... tokenizer=image_processor, +... image_processor=image_processor, ... save_strategy="no", ... ) >>> callbacks = [metric_callback, push_to_hub_callback] diff --git a/docs/source/ja/tasks/object_detection.md b/docs/source/ja/tasks/object_detection.md index 389e7bdf2f4..e90cb4645a1 100644 --- a/docs/source/ja/tasks/object_detection.md +++ b/docs/source/ja/tasks/object_detection.md @@ -376,7 +376,7 @@ DETR モデルをトレーニングできる「ラベル」。画像プロセッ ... args=training_args, ... data_collator=collate_fn, ... train_dataset=cppe5["train"], -... tokenizer=image_processor, +... image_processor=image_processor, ... ) >>> trainer.train() diff --git a/docs/source/ja/tasks/semantic_segmentation.md b/docs/source/ja/tasks/semantic_segmentation.md index 2816688b4e1..bc4c8fdc103 100644 --- a/docs/source/ja/tasks/semantic_segmentation.md +++ b/docs/source/ja/tasks/semantic_segmentation.md @@ -434,7 +434,7 @@ TensorFlow でモデルを微調整するには、次の手順に従います。 ... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"] ... ) ->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor) +>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor) >>> callbacks = [metric_callback, push_to_hub_callback] ``` diff --git a/docs/source/ja/tasks/sequence_classification.md b/docs/source/ja/tasks/sequence_classification.md index 6673cfe9e56..767d5e03cdf 100644 --- a/docs/source/ja/tasks/sequence_classification.md +++ b/docs/source/ja/tasks/sequence_classification.md @@ -436,7 +436,7 @@ TensorFlow でモデルを微調整するには、次の手順に従います。 ... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"] ... ) ->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor) +>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor) >>> callbacks = [metric_callback, push_to_hub_callback] ``` diff --git a/docs/source/ja/tasks/video_classification.md b/docs/source/ja/tasks/video_classification.md index e0c38361941..b0b5139028b 100644 --- a/docs/source/ja/tasks/video_classification.md +++ b/docs/source/ja/tasks/video_classification.md @@ -414,7 +414,7 @@ def compute_metrics(eval_pred): ... args, ... train_dataset=train_dataset, ... eval_dataset=val_dataset, -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... data_collator=collate_fn, ... ) diff --git a/docs/source/ko/tasks/image_classification.md b/docs/source/ko/tasks/image_classification.md index 031e01ea5c5..055100d4c0b 100644 --- a/docs/source/ko/tasks/image_classification.md +++ b/docs/source/ko/tasks/image_classification.md @@ -321,7 +321,7 @@ food["test"].set_transform(preprocess_val) ... data_collator=data_collator, ... train_dataset=food["train"], ... eval_dataset=food["test"], -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... ) @@ -417,7 +417,7 @@ TensorFlow에서 모델을 미세 조정하려면 다음 단계를 따르세요: >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset) >>> push_to_hub_callback = PushToHubCallback( ... output_dir="food_classifier", -... tokenizer=image_processor, +... image_processor=image_processor, ... save_strategy="no", ... ) >>> callbacks = [metric_callback, push_to_hub_callback] diff --git a/docs/source/ko/tasks/object_detection.md b/docs/source/ko/tasks/object_detection.md index 0076bba6f84..1eeada9a50e 100644 --- a/docs/source/ko/tasks/object_detection.md +++ b/docs/source/ko/tasks/object_detection.md @@ -366,7 +366,7 @@ DatasetDict({ ... args=training_args, ... data_collator=collate_fn, ... train_dataset=cppe5["train"], -... tokenizer=image_processor, +... image_processor=image_processor, ... ) >>> trainer.train() diff --git a/docs/source/ko/tasks/semantic_segmentation.md b/docs/source/ko/tasks/semantic_segmentation.md index 4b6109d692b..4c23b2ad80e 100644 --- a/docs/source/ko/tasks/semantic_segmentation.md +++ b/docs/source/ko/tasks/semantic_segmentation.md @@ -424,7 +424,7 @@ TensorFlow에서 모델을 미세 조정하려면 다음 단계를 따르세요: ... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"] ... ) ->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor) +>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor) >>> callbacks = [metric_callback, push_to_hub_callback] ``` diff --git a/docs/source/ko/tasks/video_classification.md b/docs/source/ko/tasks/video_classification.md index 01dbb0757b6..4d13f9ac610 100644 --- a/docs/source/ko/tasks/video_classification.md +++ b/docs/source/ko/tasks/video_classification.md @@ -411,7 +411,7 @@ def compute_metrics(eval_pred): ... args, ... train_dataset=train_dataset, ... eval_dataset=val_dataset, -... tokenizer=image_processor, +... image_processor=image_processor, ... compute_metrics=compute_metrics, ... data_collator=collate_fn, ... ) diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index ff01600cb32..1c952e56014 100755 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -411,7 +411,7 @@ def main(): train_dataset=dataset["train"] if training_args.do_train else None, eval_dataset=dataset["validation"] if training_args.do_eval else None, compute_metrics=compute_metrics, - tokenizer=image_processor, + image_processor=image_processor, data_collator=collate_fn, ) diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py index a23e41df611..0f098caf023 100644 --- a/examples/pytorch/image-pretraining/run_mae.py +++ b/examples/pytorch/image-pretraining/run_mae.py @@ -369,7 +369,7 @@ def main(): args=training_args, train_dataset=ds["train"] if training_args.do_train else None, eval_dataset=ds["validation"] if training_args.do_eval else None, - tokenizer=image_processor, + image_processor=image_processor, data_collator=collate_fn, ) diff --git a/examples/pytorch/image-pretraining/run_mim.py b/examples/pytorch/image-pretraining/run_mim.py index 625a96f14e5..e1afeece12c 100644 --- a/examples/pytorch/image-pretraining/run_mim.py +++ b/examples/pytorch/image-pretraining/run_mim.py @@ -458,7 +458,7 @@ def main(): args=training_args, train_dataset=ds["train"] if training_args.do_train else None, eval_dataset=ds["validation"] if training_args.do_eval else None, - tokenizer=image_processor, + image_processor=image_processor, data_collator=collate_fn, ) diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py index 957b78b9b56..8324531ccb0 100644 --- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py +++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py @@ -510,7 +510,7 @@ def main(): train_dataset=dataset["train"] if training_args.do_train else None, eval_dataset=dataset["validation"] if training_args.do_eval else None, compute_metrics=compute_metrics, - tokenizer=image_processor, + image_processor=image_processor, data_collator=default_data_collator, ) diff --git a/examples/tensorflow/image-classification/run_image_classification.py b/examples/tensorflow/image-classification/run_image_classification.py index 3e2b43bca10..ab2de73a3b8 100644 --- a/examples/tensorflow/image-classification/run_image_classification.py +++ b/examples/tensorflow/image-classification/run_image_classification.py @@ -552,7 +552,7 @@ def main(): output_dir=training_args.output_dir, hub_model_id=push_to_hub_model_id, hub_token=training_args.push_to_hub_token, - tokenizer=image_processor, + image_processor=image_processor, **model_card_kwargs, ) ) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6bcf4796f8d..436165b0e3d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -59,6 +59,7 @@ from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .debug_utils import DebugOption, DebugUnderflowOverflow from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend +from .image_processing_utils import BaseImageProcessor from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available from .integrations.tpu import tpu_spmd_dataloader from .modelcard import TrainingSummary @@ -303,6 +304,9 @@ class Trainer: The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. + image_processor ([`BaseImageProcessor`], *optional*): + The image processor used to preprocess the data. If provided, it will be saved along the model to make it easier + to rerun an interrupted training or reuse the fine-tuned model. model_init (`Callable[[], PreTrainedModel]`, *optional*): A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start from a new instance of the model as given by this function. @@ -357,6 +361,7 @@ class Trainer: train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, + image_processor: Optional["BaseImageProcessor"] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, @@ -485,11 +490,12 @@ class Trainer: ): self.place_model_on_device = False - default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer + self.image_processor = image_processor # Bnb Quantized models doesn't support `.to` operation. if ( @@ -541,7 +547,7 @@ class Trainer: default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( - callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + callbacks, self.model, self.tokenizer, self.image_processor, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) @@ -3276,6 +3282,8 @@ class Trainer: ) if self.tokenizer is not None and self.args.should_save: self.tokenizer.save_pretrained(output_dir) + if self.image_processor is not None and self.args.should_save: + self.image_processor.save_pretrained(output_dir) # We moved the model from TPU -> CPU for saving the weights. # Now we should move it back to subsequent compute still works. @@ -3313,6 +3321,8 @@ class Trainer: if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) + if self.image_processor is not None: + self.image_processor.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) @@ -4009,6 +4019,9 @@ class Trainer: # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) + # Same for the image processor + if self.image_processor is not None: + self.image_processor.save_pretrained(output_dir) # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) @@ -4056,7 +4069,7 @@ class Trainer: def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: """ - Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`. + Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`. Parameters: commit_message (`str`, *optional*, defaults to `"End of training"`): diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 1e3b0e587a7..a9cb6eca596 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -189,6 +189,8 @@ class TrainerCallback: The model being trained. tokenizer ([`PreTrainedTokenizer`]): The tokenizer used for encoding the data. + image_processor ([`BaseImageProcessor`]): + The image processor used for encoding the images. optimizer (`torch.optim.Optimizer`): The optimizer used for the training steps. lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`): @@ -307,12 +309,13 @@ class TrainerCallback: class CallbackHandler(TrainerCallback): """Internal class that just calls the list of callbacks in order.""" - def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler): + def __init__(self, callbacks, model, tokenizer, image_processor, optimizer, lr_scheduler): self.callbacks = [] for cb in callbacks: self.add_callback(cb) self.model = model self.tokenizer = tokenizer + self.image_processor = image_processor self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.train_dataloader = None @@ -417,6 +420,7 @@ class CallbackHandler(TrainerCallback): control, model=self.model, tokenizer=self.tokenizer, + image_processor=self.image_processor, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, train_dataloader=self.train_dataloader,