mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)
* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels * Add test for numpy scalar inputs
This commit is contained in:
parent
4a6a35bc65
commit
f00f22a3e2
@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if "label" in first and first["label"] is not None:
|
||||
if isinstance(first["label"], tf.Tensor):
|
||||
dtype = tf.int64 if first["label"].dtype.is_integer() else tf.float32
|
||||
elif isinstance(first["label"], np.ndarray):
|
||||
dtype = tf.int64 if np.issubdtype(first["label"].dtype, np.integer) else tf.float32
|
||||
elif isinstance(first["label"], (tuple, list)):
|
||||
dtype = tf.int64 if isinstance(first["label"][0], int) else tf.float32
|
||||
else:
|
||||
dtype = tf.int64 if isinstance(first["label"], int) else tf.float32
|
||||
batch["labels"] = tf.convert_to_tensor([f["label"] for f in features], dtype=dtype)
|
||||
label_col_name = "label"
|
||||
elif "label_ids" in first and first["label_ids"] is not None:
|
||||
if isinstance(first["label_ids"], tf.Tensor):
|
||||
batch["labels"] = tf.stack([f["label_ids"] for f in features])
|
||||
label_col_name = "label_ids"
|
||||
elif "labels" in first and first["labels"] is not None:
|
||||
label_col_name = "labels"
|
||||
else:
|
||||
label_col_name = None
|
||||
if label_col_name is not None:
|
||||
if isinstance(first[label_col_name], tf.Tensor):
|
||||
dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
|
||||
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
|
||||
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
|
||||
elif isinstance(first[label_col_name], (tuple, list)):
|
||||
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
|
||||
else:
|
||||
dtype = tf.int64 if type(first["label_ids"][0]) is int else tf.float32
|
||||
batch["labels"] = tf.convert_to_tensor([f["label_ids"] for f in features], dtype=dtype)
|
||||
|
||||
dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
|
||||
batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
|
||||
# Handling of all other possible keys.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in first.items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
|
||||
if isinstance(v, (tf.Tensor, np.ndarray)):
|
||||
batch[k] = tf.stack([f[k] for f in features])
|
||||
else:
|
||||
|
@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["labels"].dtype, tf.int64)
|
||||
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
|
||||
|
||||
def test_numpy_dtype_preservation(self):
|
||||
data_collator = default_data_collator
|
||||
|
||||
# Confirms that numpy inputs are handled correctly even when scalars
|
||||
features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)]
|
||||
batch = data_collator(features, return_tensors="tf")
|
||||
self.assertEqual(batch["labels"].dtype, tf.int64)
|
||||
|
||||
def test_default_classification_and_regression(self):
|
||||
data_collator = default_data_collator
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user