mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Removed from_config
This commit is contained in:
parent
1ca52567a4
commit
8938b546bf
@ -37,11 +37,6 @@ class Pipeline(ABC):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Provided path ({}) should be a directory".format(save_directory))
|
||||
@ -63,6 +58,12 @@ class Pipeline(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FeatureExtractionPipeline(Pipeline):
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class TextClassificationPipeline(Pipeline):
|
||||
def __init__(self, model, tokenizer: PreTrainedTokenizer, nb_classes: int = 2):
|
||||
super().__init__(model, tokenizer)
|
||||
@ -71,10 +72,6 @@ class TextClassificationPipeline(Pipeline):
|
||||
raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
|
||||
self._nb_classes = nb_classes
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
|
||||
return cls(model, tokenizer, **kwargs)
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
# Generic compatibility with sklearn and Keras
|
||||
if 'X' in kwargs and not texts:
|
||||
@ -102,10 +99,6 @@ class NerPipeline(Pipeline):
|
||||
def __init__(self, model, tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(model, tokenizer)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
(texts, ), answers = texts, []
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user