mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixes to TF collators (#21143)
* Add num_workers for prepare_tf_dataset
* Bugfix in the default collator and change default tensor type
* Remove the "num_workers" arg and move it to a new PR
This commit is contained in:
parent
2411f0e465
commit
e5dcceb82c
@@ -159,7 +159,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
||||
label_col_name = None
|
||||
if label_col_name is not None:
|
||||
if isinstance(first[label_col_name], tf.Tensor):
|
||||
dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
|
||||
dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
|
||||
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
|
||||
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
|
||||
elif isinstance(first[label_col_name], (tuple, list)):
|
||||
|
@@ -1345,9 +1345,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
if collate_fn is None:
|
||||
if tokenizer is None:
|
||||
collate_fn = DefaultDataCollator(return_tensors="tf")
|
||||
collate_fn = DefaultDataCollator(return_tensors="np")
|
||||
else:
|
||||
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
|
||||
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
|
||||
if collate_fn_args is None:
|
||||
collate_fn_args = dict()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user