diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 60522344d4f..5149ff39c84 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -159,7 +159,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: label_col_name = None if label_col_name is not None: if isinstance(first[label_col_name], tf.Tensor): - dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32 + dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32 elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic): dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32 elif isinstance(first[label_col_name], (tuple, list)): diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 2372984b71a..6572a0f8591 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1345,9 +1345,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu if collate_fn is None: if tokenizer is None: - collate_fn = DefaultDataCollator(return_tensors="tf") + collate_fn = DefaultDataCollator(return_tensors="np") else: - collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf") + collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np") if collate_fn_args is None: collate_fn_args = dict()