mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Handle Trainer tokenizer
kwarg deprecation with decorator (#33887)
* Handle deprecation with decorator * Fix for seq2seq Trainer
This commit is contained in:
parent
ee71c9853a
commit
2f25ab95db
@ -177,6 +177,7 @@ from .utils import (
|
||||
logging,
|
||||
strtobool,
|
||||
)
|
||||
from .utils.deprecation import deprecate_kwarg
|
||||
from .utils.quantization_config import QuantizationMethod
|
||||
|
||||
|
||||
@ -326,11 +327,6 @@ class Trainer:
|
||||
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
|
||||
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
|
||||
dataset prepending the dictionary key to the metric name.
|
||||
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
|
||||
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
|
||||
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
|
||||
interrupted training or reuse the fine-tuned model.
|
||||
This is now deprecated.
|
||||
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
@ -385,6 +381,7 @@ class Trainer:
|
||||
# Those are used as methods of the Trainer in examples.
|
||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
@ -392,7 +389,6 @@ class Trainer:
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
@ -437,17 +433,6 @@ class Trainer:
|
||||
# force device and distributed setup init explicitly
|
||||
args._setup_devices
|
||||
|
||||
if tokenizer is not None:
|
||||
if processing_class is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `tokenizer` and `processing_class` at the same time. Please use `processing_class`."
|
||||
)
|
||||
warnings.warn(
|
||||
"`tokenizer` is now deprecated and will be removed in v5, please use `processing_class` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
processing_class = tokenizer
|
||||
|
||||
if model is None:
|
||||
if model_init is not None:
|
||||
self.model_init = model_init
|
||||
|
@ -25,6 +25,7 @@ from .generation.configuration_utils import GenerationConfig
|
||||
from .integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from .trainer import Trainer
|
||||
from .utils import logging
|
||||
from .utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -43,6 +44,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union["PreTrainedModel", nn.Module] = None,
|
||||
@ -50,7 +52,6 @@ class Seq2SeqTrainer(Trainer):
|
||||
data_collator: Optional["DataCollator"] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
processing_class: Optional[
|
||||
Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
|
||||
] = None,
|
||||
@ -66,7 +67,6 @@ class Seq2SeqTrainer(Trainer):
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
processing_class=processing_class,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
|
Loading…
Reference in New Issue
Block a user