Fixes to TF collators (#21143)

* Add num_workers for prepare_tf_dataset

* Bugfix in the default collator and change default tensor type

* Remove the "num_workers" arg and move it to a new PR
Author: Matt (committed by GitHub), 2023-01-17 12:18:56 +00:00
parent 2411f0e465
commit e5dcceb82c
2 changed files with 3 additions and 3 deletions

src/transformers/data/data_collator.py

@@ -159,7 +159,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
         label_col_name = None
     if label_col_name is not None:
         if isinstance(first[label_col_name], tf.Tensor):
-            dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
+            dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
         elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
            dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
         elif isinstance(first[label_col_name], (tuple, list)):

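The fix above works because tf.DType.is_integer is a property, not a method: the old code called the bool it returned, which raised a TypeError the first time a tf.Tensor label reached the collator. A minimal sketch of the failure mode, standalone and outside the collator:

    import tensorflow as tf

    t = tf.constant([1, 2, 3])

    # is_integer is a property on tf.DType, so reading it yields a bool
    print(t.dtype.is_integer)  # True

    # The pre-fix code effectively did this, which fails:
    try:
        t.dtype.is_integer()
    except TypeError as err:
        print(err)  # 'bool' object is not callable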
src/transformers/modeling_tf_utils.py

@@ -1345,9 +1345,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if collate_fn is None:
             if tokenizer is None:
-                collate_fn = DefaultDataCollator(return_tensors="tf")
+                collate_fn = DefaultDataCollator(return_tensors="np")
             else:
-                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
+                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
         if collate_fn_args is None:
             collate_fn_args = dict()
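
Switching the default collators to return_tensors="np" means prepare_tf_dataset hands NumPy batches to tf.data, which converts them itself, instead of materializing TF tensors eagerly inside the collator. A minimal usage sketch of the new default path (the checkpoint name and example strings are illustrative, not from this commit):

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")

    # Two unpadded encodings, like the rows prepare_tf_dataset pulls from a dataset
    features = [tokenizer("short text"), tokenizer("a somewhat longer piece of text")]
    batch = collate_fn(features)

    print(type(batch["input_ids"]))  # <class 'numpy.ndarray'>, padded to a common length
    print(batch["input_ids"].shape)  # (2, longest_sequence_in_batch)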