Removed from_config

This commit is contained in:
Morgan Funtowicz 2019-12-13 14:27:04 +01:00
parent 1ca52567a4
commit 8938b546bf

View File

@ -37,11 +37,6 @@ class Pipeline(ABC):
self.model = model
self.tokenizer = tokenizer
@classmethod
@abstractmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
raise NotImplementedError()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
@ -63,6 +58,12 @@ class Pipeline(ABC):
raise NotImplementedError()
class FeatureExtractionPipeline(Pipeline):
def __call__(self, *texts, **kwargs):
pass
class TextClassificationPipeline(Pipeline):
def __init__(self, model, tokenizer: PreTrainedTokenizer, nb_classes: int = 2):
super().__init__(model, tokenizer)
@ -71,10 +72,6 @@ class TextClassificationPipeline(Pipeline):
raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
self._nb_classes = nb_classes
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
return cls(model, tokenizer, **kwargs)
def __call__(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
if 'X' in kwargs and not texts:
@ -102,10 +99,6 @@ class NerPipeline(Pipeline):
def __init__(self, model, tokenizer: PreTrainedTokenizer):
super().__init__(model, tokenizer)
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
def __call__(self, *texts, **kwargs):
(texts, ), answers = texts, []