diff --git a/transformers/__init__.py b/transformers/__init__.py index a133425a9cf..b29ad38e735 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -29,7 +29,7 @@ from .data import (is_sklearn_available, xnli_output_modes, xnli_processors, xnli_tasks_num_labels) if is_sklearn_available(): - from .data import glue_compute_metrics + from .data import glue_compute_metrics, xnli_compute_metrics # Tokenizers from .tokenization_utils import (PreTrainedTokenizer) diff --git a/transformers/data/__init__.py b/transformers/data/__init__.py index 46615608a4d..b811a358079 100644 --- a/transformers/data/__init__.py +++ b/transformers/data/__init__.py @@ -4,4 +4,4 @@ from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_label from .metrics import is_sklearn_available if is_sklearn_available(): - from .metrics import glue_compute_metrics + from .metrics import glue_compute_metrics, xnli_compute_metrics diff --git a/transformers/data/metrics/__init__.py b/transformers/data/metrics/__init__.py index c9ebaac38df..5a46eb05d3b 100644 --- a/transformers/data/metrics/__init__.py +++ b/transformers/data/metrics/__init__.py @@ -81,3 +81,11 @@ if _has_sklearn: return {"acc": simple_accuracy(preds, labels)} else: raise KeyError(task_name) + + + def xnli_compute_metrics(task_name, preds, labels): + assert len(preds) == len(labels) + if task_name == "xnli": + return {"acc": simple_accuracy(preds, labels)} + else: + raise KeyError(task_name) diff --git a/transformers/data/processors/xnli.py b/transformers/data/processors/xnli.py index a4807dd901c..ce582f31a68 100644 --- a/transformers/data/processors/xnli.py +++ b/transformers/data/processors/xnli.py @@ -73,13 +73,6 @@ class XnliProcessor(DataProcessor): """See base class.""" return ["contradiction", "entailment", "neutral"] -def xnli_compute_metrics(task_name, preds, labels): - assert len(preds) == len(labels) - if task_name == "xnli": - return {"acc": simple_accuracy(preds, labels)} - else: - raise ValueError('{} is not a supported task.'.format(task_name)) - xnli_processors = { "xnli": XnliProcessor, }