Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)

* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels

* Add test for numpy scalar inputs
This commit is contained in:
Matt 2022-01-20 14:26:51 +00:00 committed by GitHub
parent 4a6a35bc65
commit f00f22a3e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 15 deletions

View File

@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if "label" in first and first["label"] is not None:
if isinstance(first["label"], tf.Tensor):
dtype = tf.int64 if first["label"].dtype.is_integer() else tf.float32
elif isinstance(first["label"], np.ndarray):
dtype = tf.int64 if np.issubdtype(first["label"].dtype, np.integer) else tf.float32
elif isinstance(first["label"], (tuple, list)):
dtype = tf.int64 if isinstance(first["label"][0], int) else tf.float32
else:
dtype = tf.int64 if isinstance(first["label"], int) else tf.float32
batch["labels"] = tf.convert_to_tensor([f["label"] for f in features], dtype=dtype)
label_col_name = "label"
elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], tf.Tensor):
batch["labels"] = tf.stack([f["label_ids"] for f in features])
label_col_name = "label_ids"
elif "labels" in first and first["labels"] is not None:
label_col_name = "labels"
else:
label_col_name = None
if label_col_name is not None:
if isinstance(first[label_col_name], tf.Tensor):
dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
elif isinstance(first[label_col_name], (tuple, list)):
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
else:
dtype = tf.int64 if type(first["label_ids"][0]) is int else tf.float32
batch["labels"] = tf.convert_to_tensor([f["label_ids"] for f in features], dtype=dtype)
dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
if isinstance(v, (tf.Tensor, np.ndarray)):
batch[k] = tf.stack([f[k] for f in features])
else:

View File

@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, tf.int64)
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
def test_numpy_dtype_preservation(self):
data_collator = default_data_collator
# Confirms that numpy inputs are handled correctly even when scalars
features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)]
batch = data_collator(features, return_tensors="tf")
self.assertEqual(batch["labels"].dtype, tf.int64)
def test_default_classification_and_regression(self):
data_collator = default_data_collator