mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
[Trainer] Fix default data collator (#30142)
* Fix data collator
* Support feature extractors as well
This commit is contained in:
parent
ec59a42192
commit
ba1b24e07b
@@ -58,6 +58,7 @@ from . import __version__
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||||
|
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||||
from .image_processing_utils import BaseImageProcessor
|
from .image_processing_utils import BaseImageProcessor
|
||||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
||||||
@@ -492,7 +493,11 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
self.place_model_on_device = False
|
self.place_model_on_device = False
|
||||||
|
|
||||||
default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator
|
default_collator = (
|
||||||
|
DataCollatorWithPadding(tokenizer)
|
||||||
|
if tokenizer is not None and isinstance(tokenizer, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
|
||||||
|
else default_data_collator
|
||||||
|
)
|
||||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
|
Loading…
Reference in New Issue
Block a user