mirror of https://github.com/huggingface/transformers.git
[Trainer] Allow passing image processor (#29896)
* Add image processor to trainer
* Replace tokenizer=image_processor everywhere
This commit is contained in:
parent d704c0b698
commit 1ab7136488
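
In practice, fine-tuning code that previously smuggled the image processor through the `tokenizer` argument can now pass it under its own name. A minimal sketch of the new call, assuming a dataset `train_ds` that has already been transformed with the processor (the checkpoint name is illustrative, not from this diff):

    from transformers import (
        AutoImageProcessor,
        AutoModelForImageClassification,
        Trainer,
        TrainingArguments,
    )

    checkpoint = "google/vit-base-patch16-224-in21k"  # illustrative checkpoint
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="food_classifier"),
        train_dataset=train_ds,  # assumed: already preprocessed with image_processor
        image_processor=image_processor,  # new keyword; was tokenizer=image_processor
    )
    trainer.train()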
@@ -322,7 +322,7 @@ At this point, only three steps remain:
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -418,7 +418,7 @@ and use the [PushToHubCallback](../main_classes/keras_callbacks#transformers.Pus
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]
@@ -384,7 +384,7 @@ Finally, bring everything together, and call [`~transformers.Trainer.train`]:
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )

 >>> trainer.train()
@@ -642,7 +642,7 @@ and use the [`PushToHubCallback`] to upload the model:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )

->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)

 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```
@@ -407,7 +407,7 @@ Then you just pass all of this along with the datasets to `Trainer`:
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )
@@ -160,7 +160,7 @@ At this point, only three steps remain:
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )

 >>> trainer.train()
@@ -328,7 +328,7 @@ food["test"].set_transform(preprocess_val)
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -426,7 +426,7 @@ Convert your datasets to the `tf.data.Dataset` format using the [`~datasets.Data
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]
@@ -376,7 +376,7 @@ "labels" on which the DETR model can be trained. The image processo
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )

 >>> trainer.train()
@@ -434,7 +434,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )

->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)

 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```
@@ -436,7 +436,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )

->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)

 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```
@@ -414,7 +414,7 @@ def compute_metrics(eval_pred):
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )
@@ -321,7 +321,7 @@ food["test"].set_transform(preprocess_val)
 ...     data_collator=data_collator,
 ...     train_dataset=food["train"],
 ...     eval_dataset=food["test"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ... )
@@ -417,7 +417,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 >>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
 >>> push_to_hub_callback = PushToHubCallback(
 ...     output_dir="food_classifier",
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     save_strategy="no",
 ... )
 >>> callbacks = [metric_callback, push_to_hub_callback]
@@ -366,7 +366,7 @@ DatasetDict({
 ...     args=training_args,
 ...     data_collator=collate_fn,
 ...     train_dataset=cppe5["train"],
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ... )

 >>> trainer.train()
@@ -424,7 +424,7 @@ To fine-tune a model in TensorFlow, follow these steps:
 ...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
 ... )

->>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)
+>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)

 >>> callbacks = [metric_callback, push_to_hub_callback]
 ```
@@ -411,7 +411,7 @@ def compute_metrics(eval_pred):
 ...     args,
 ...     train_dataset=train_dataset,
 ...     eval_dataset=val_dataset,
-...     tokenizer=image_processor,
+...     image_processor=image_processor,
 ...     compute_metrics=compute_metrics,
 ...     data_collator=collate_fn,
 ... )
@@ -411,7 +411,7 @@ def main():
         train_dataset=dataset["train"] if training_args.do_train else None,
         eval_dataset=dataset["validation"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )
@@ -369,7 +369,7 @@ def main():
         args=training_args,
         train_dataset=ds["train"] if training_args.do_train else None,
         eval_dataset=ds["validation"] if training_args.do_eval else None,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )
@@ -458,7 +458,7 @@ def main():
         args=training_args,
         train_dataset=ds["train"] if training_args.do_train else None,
         eval_dataset=ds["validation"] if training_args.do_eval else None,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=collate_fn,
     )
@@ -510,7 +510,7 @@ def main():
         train_dataset=dataset["train"] if training_args.do_train else None,
         eval_dataset=dataset["validation"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
-        tokenizer=image_processor,
+        image_processor=image_processor,
         data_collator=default_data_collator,
     )
@@ -552,7 +552,7 @@ def main():
             output_dir=training_args.output_dir,
             hub_model_id=push_to_hub_model_id,
             hub_token=training_args.push_to_hub_token,
-            tokenizer=image_processor,
+            image_processor=image_processor,
             **model_card_kwargs,
         )
     )
@@ -59,6 +59,7 @@ from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .image_processing_utils import BaseImageProcessor
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
@@ -303,6 +304,9 @@ class Trainer:
             The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
             maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
             interrupted training or reuse the fine-tuned model.
+        image_processor ([`BaseImageProcessor`], *optional*):
+            The image processor used to preprocess the data. If provided, it will be saved along the model to make it easier
+            to rerun an interrupted training or reuse the fine-tuned model.
         model_init (`Callable[[], PreTrainedModel]`, *optional*):
             A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
             from a new instance of the model as given by this function.
@@ -357,6 +361,7 @@ class Trainer:
         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        image_processor: Optional["BaseImageProcessor"] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
@@ -485,11 +490,12 @@ class Trainer:
         ):
             self.place_model_on_device = False

-        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+        default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
+        self.image_processor = image_processor

         # Bnb Quantized models doesn't support `.to` operation.
         if (
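
Note that the collator default above still keys off `tokenizer` alone: a `DataCollatorWithPadding` is only built when a tokenizer is passed, so a `Trainer` constructed with just an `image_processor` falls back to `default_data_collator` unless it is given an explicit `data_collator`. That is why the task docs in the hunks above all supply their own `collate_fn`; a typical one looks roughly like this (a sketch, with field names assumed from the standard image-classification preprocessing):

    import torch

    def collate_fn(examples):
        # Stack preprocessed images and labels into batch tensors; the image
        # processor itself plays no part in collation.
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}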
@@ -541,7 +547,7 @@ class Trainer:
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
-            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+            callbacks, self.model, self.tokenizer, self.image_processor, self.optimizer, self.lr_scheduler
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
@@ -3276,6 +3282,8 @@ class Trainer:
             )
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
+        if self.image_processor is not None and self.args.should_save:
+            self.image_processor.save_pretrained(output_dir)

         # We moved the model from TPU -> CPU for saving the weights.
         # Now we should move it back to subsequent compute still works.
@@ -3313,6 +3321,8 @@ class Trainer:

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(output_dir)

         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
@@ -4009,6 +4019,9 @@ class Trainer:
         # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
+        # Same for the image processor
+        if self.image_processor is not None:
+            self.image_processor.save_pretrained(output_dir)
         # Same for the training arguments
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
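
With the three save paths above covered, a checkpoint directory written by the `Trainer` now carries the preprocessing config (`preprocessor_config.json`) alongside the weights, so resuming an interrupted run or reusing the fine-tuned model no longer requires tracking the processor separately. Roughly, assuming `food_classifier` was the training `output_dir`:

    from transformers import AutoImageProcessor, AutoModelForImageClassification

    model = AutoModelForImageClassification.from_pretrained("food_classifier")
    image_processor = AutoImageProcessor.from_pretrained("food_classifier")  # saved by the hunks above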
@@ -4056,7 +4069,7 @@ class Trainer:

     def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
         """
-        Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
+        Upload `self.model` and `self.tokenizer` or `self.image_processor` to the 🤗 model hub on the repo `self.args.hub_model_id`.

         Parameters:
             commit_message (`str`, *optional*, defaults to `"End of training"`):
@@ -189,6 +189,8 @@ class TrainerCallback:
             The model being trained.
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer used for encoding the data.
+        image_processor ([`BaseImageProcessor`]):
+            The image processor used for encoding the images.
         optimizer (`torch.optim.Optimizer`):
             The optimizer used for the training steps.
         lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
@@ -307,12 +309,13 @@ class TrainerCallback:
 class CallbackHandler(TrainerCallback):
     """Internal class that just calls the list of callbacks in order."""

-    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
+    def __init__(self, callbacks, model, tokenizer, image_processor, optimizer, lr_scheduler):
         self.callbacks = []
         for cb in callbacks:
             self.add_callback(cb)
         self.model = model
         self.tokenizer = tokenizer
+        self.image_processor = image_processor
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
         self.train_dataloader = None
@@ -417,6 +420,7 @@ class CallbackHandler(TrainerCallback):
             control,
             model=self.model,
             tokenizer=self.tokenizer,
+            image_processor=self.image_processor,
             optimizer=self.optimizer,
             lr_scheduler=self.lr_scheduler,
             train_dataloader=self.train_dataloader,
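
Finally, because `CallbackHandler.call_event` now forwards the image processor, custom callbacks can read it from `kwargs` like any other training object. A minimal sketch (the event and the logging are arbitrary choices, not from this diff):

    from transformers import TrainerCallback

    class ReportProcessorCallback(TrainerCallback):
        def on_train_begin(self, args, state, control, **kwargs):
            # image_processor arrives via kwargs, next to model, tokenizer, etc.
            image_processor = kwargs.get("image_processor")
            if image_processor is not None:
                print(f"Training with image processor: {type(image_processor).__name__}")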