diff --git a/src/transformers/data/processors/utils.py b/src/transformers/data/processors/utils.py index d16d72f6cb6..4cc931cdf9c 100644 --- a/src/transformers/data/processors/utils.py +++ b/src/transformers/data/processors/utils.py @@ -93,6 +93,33 @@ class InputFeatures(object): class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" + def get_example_from_tensor_dict(self, tensor_dict): + """Gets an example from a dict with tensorflow tensors + Args: + tensor_dict: Keys and values should match the corresponding Glue + tensorflow_dataset examples. + """ + raise NotImplementedError() + + def get_train_examples(self, data_dir): + """Gets a collection of `InputExample`s for the train set.""" + raise NotImplementedError() + + def get_dev_examples(self, data_dir): + """Gets a collection of `InputExample`s for the dev set.""" + raise NotImplementedError() + + def get_labels(self): + """Gets the list of labels for this data set.""" + raise NotImplementedError() + + def tfds_map(self, example): + """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. + This method converts examples to the correct format.""" + if len(self.get_labels()) > 1: + example.label = self.get_labels()[int(example.label)] + return example + @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file."""