mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Complete DataProcessor class
This commit is contained in:
parent
c45d0cf60f
commit
1efc208ff3
@ -93,6 +93,33 @@ class InputFeatures(object):
|
|||||||
class DataProcessor(object):
|
class DataProcessor(object):
|
||||||
"""Base class for data converters for sequence classification data sets."""
|
"""Base class for data converters for sequence classification data sets."""
|
||||||
|
|
||||||
|
def get_example_from_tensor_dict(self, tensor_dict):
|
||||||
|
"""Gets an example from a dict with tensorflow tensors
|
||||||
|
Args:
|
||||||
|
tensor_dict: Keys and values should match the corresponding Glue
|
||||||
|
tensorflow_dataset examples.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""Gets a collection of `InputExample`s for the train set."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""Gets a collection of `InputExample`s for the dev set."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""Gets the list of labels for this data set."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def tfds_map(self, example):
|
||||||
|
"""Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
|
||||||
|
This method converts examples to the correct format."""
|
||||||
|
if len(self.get_labels()) > 1:
|
||||||
|
example.label = self.get_labels()[int(example.label)]
|
||||||
|
return example
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _read_tsv(cls, input_file, quotechar=None):
|
def _read_tsv(cls, input_file, quotechar=None):
|
||||||
"""Reads a tab separated value file."""
|
"""Reads a tab separated value file."""
|
||||||
|
Loading…
Reference in New Issue
Block a user