mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix KerasMetricCallback prediction with generate() and inference of column names (#15351)
* Fix prediction with generate() and the inference of column names Should now have very few differences with the PyTorch implementation * Minor edit to parent class * Update src/transformers/keras_callbacks.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Explaining the dict conversion * Putting main_input_name back * Fixes to main_input_name Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
da5ef25db9
commit
6beae766ee
@ -56,8 +56,6 @@ class KerasMetricCallback(Callback):
|
||||
metric names to numerical values.
|
||||
eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
|
||||
Validation data to be used to generate predictions for the `metric_fn`.
|
||||
metric_fn_kwargs (`dict`, *optional*):
|
||||
Additional keyword arguments to be passed to the metric_fn.
|
||||
output_cols (`List[str], *optional*):
|
||||
A list of columns to be retained from the model output as the predictions. Defaults to all.
|
||||
label_cols ('`List[str]`, *optional*'):
|
||||
@ -74,7 +72,6 @@ class KerasMetricCallback(Callback):
|
||||
self,
|
||||
metric_fn: Callable,
|
||||
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
|
||||
metric_fn_kwargs: Optional[dict] = None,
|
||||
output_cols: Optional[List[str]] = None,
|
||||
label_cols: Optional[List[str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
@ -94,12 +91,6 @@ class KerasMetricCallback(Callback):
|
||||
self.eval_dataset = eval_dataset
|
||||
self.predict_with_generate = predict_with_generate
|
||||
self.output_cols = output_cols
|
||||
self.metric_fn_kwargs = metric_fn_kwargs or dict()
|
||||
|
||||
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
|
||||
self.main_input_name = self.model.encoder.main_input_name
|
||||
else:
|
||||
self.main_input_name = self.model.main_input_name
|
||||
|
||||
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
|
||||
# that is passed to the metric_fn
|
||||
@ -123,32 +114,75 @@ class KerasMetricCallback(Callback):
|
||||
self.label_cols = ["labels"]
|
||||
self.use_keras_label = False
|
||||
logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
|
||||
elif "start_positions" in input_spec and "end_positions" in input_spec:
|
||||
self.label_cols = ["start_positions", "end_positions"]
|
||||
self.use_keras_label = False
|
||||
logging.warning(
|
||||
"No label_cols specified for KerasMetricCallback, assuming you want the "
|
||||
"start_positions and end_positions keys."
|
||||
)
|
||||
else:
|
||||
raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
|
||||
if parse(tf.__version__).minor < parse("2.7"):
|
||||
if parse(tf.__version__) < parse("2.7"):
|
||||
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
|
||||
|
||||
@staticmethod
|
||||
def _concatenate_batches(batches):
|
||||
# Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
|
||||
return [sample for batch in batches for sample in batch]
|
||||
def _concatenate_batches(batches, padding_index=-100):
|
||||
# If all batches are unidimensional or same length, do a simple concatenation
|
||||
if batches[0].ndim == 1 or all([batch.shape[1] == batches[0].shape[1] for batch in batches]):
|
||||
return np.concatenate(batches, axis=0)
|
||||
|
||||
# Welp, they're not the same length. Let's do some padding
|
||||
max_len = max([batch.shape[1] for batch in batches])
|
||||
num_samples = sum([batch.shape[0] for batch in batches])
|
||||
output = np.full_like(
|
||||
batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
|
||||
)
|
||||
# i keeps track of which part of the concatenated array we're writing the next batch to
|
||||
i = 0
|
||||
for batch in batches:
|
||||
output[i : i + len(batch), : batch.shape[1]] = batch
|
||||
i += len(batch)
|
||||
return output
|
||||
|
||||
def _postprocess_predictions_or_labels(self, inputs):
|
||||
if isinstance(inputs[0], dict):
|
||||
outputs = dict()
|
||||
for key in inputs[0].keys():
|
||||
outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
|
||||
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
|
||||
# If it's a dict with only one key, just return the array
|
||||
if len(outputs) == 1:
|
||||
outputs = list(outputs.values())[0]
|
||||
elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
|
||||
outputs = []
|
||||
for input_list in zip(*inputs):
|
||||
outputs.append(self._concatenate_batches(input_list))
|
||||
if len(outputs) == 1:
|
||||
outputs = outputs[0] # If it's a list with only one element, just return the array
|
||||
elif isinstance(inputs[0], np.ndarray):
|
||||
outputs = self._concatenate_batches(inputs)
|
||||
elif isinstance(inputs[0], tf.Tensor):
|
||||
outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
|
||||
else:
|
||||
raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
|
||||
return outputs
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
if hasattr(self.model, "config"):
|
||||
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
main_input_name = None
|
||||
if self.predict_with_generate:
|
||||
# This dense conditional recognizes the case where we have an encoder-decoder model, but
|
||||
# avoids getting tangled up when we just have a model with a layer called 'encoder'
|
||||
if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
|
||||
if self.model.encoder.main_input_name != self.model.main_input_name:
|
||||
main_input_name = self.model.encoder.main_input_name
|
||||
else:
|
||||
main_input_name = getattr(self.model, "main_input_name", "input_ids")
|
||||
|
||||
prediction_list = []
|
||||
label_list = []
|
||||
|
||||
@ -160,7 +194,7 @@ class KerasMetricCallback(Callback):
|
||||
labels = None
|
||||
if self.predict_with_generate:
|
||||
if isinstance(batch, dict):
|
||||
generation_inputs = batch[self.main_input_name]
|
||||
generation_inputs = batch[main_input_name]
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
else:
|
||||
generation_inputs = batch
|
||||
@ -169,9 +203,14 @@ class KerasMetricCallback(Callback):
|
||||
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
|
||||
else:
|
||||
predictions = self.model.predict(batch)
|
||||
predictions = dict(predictions)
|
||||
if self.output_cols is not None:
|
||||
predictions = {key: predictions[key] for key in self.output_cols}
|
||||
if isinstance(predictions, dict):
|
||||
# This converts any dict-subclass to a regular dict
|
||||
# Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
|
||||
predictions = dict(predictions)
|
||||
if self.output_cols is not None:
|
||||
predictions = {key: predictions[key] for key in self.output_cols}
|
||||
else:
|
||||
predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
|
||||
prediction_list.append(predictions)
|
||||
if not self.use_keras_label:
|
||||
labels = {key: batch[key].numpy() for key in self.label_cols}
|
||||
@ -185,10 +224,10 @@ class KerasMetricCallback(Callback):
|
||||
raise TypeError(f"Confused by labels of type {type(labels)}")
|
||||
label_list.append(labels)
|
||||
|
||||
prediction_list = self._postprocess_predictions_or_labels(prediction_list)
|
||||
label_list = self._postprocess_predictions_or_labels(label_list)
|
||||
all_preds = self._postprocess_predictions_or_labels(prediction_list)
|
||||
all_labels = self._postprocess_predictions_or_labels(label_list)
|
||||
|
||||
metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
|
||||
metric_output = self.metric_fn((all_preds, all_labels))
|
||||
if not isinstance(metric_output, dict):
|
||||
raise TypeError(
|
||||
f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
|
||||
|
Loading…
Reference in New Issue
Block a user