Added FeatureExtraction pipeline.

This commit is contained in:
Morgan Funtowicz 2019-12-15 01:37:52 +01:00
parent f1971bf303
commit 8e3b1c860f

View File

@@ -143,7 +143,23 @@ class JsonPipelineDataFormat(PipelineDataFormat):
class FeatureExtractionPipeline(Pipeline):
    """Pipeline returning the first output of the underlying model for the given texts."""

    def __call__(self, *texts, **kwargs):
        """Encode the texts, run the model and return its first output as nested lists.

        Args:
            *texts: Texts to encode, passed positionally.
            **kwargs: May carry the texts under the key ``X`` (sklearn / Keras
                calling convention). ``X`` is only consulted when no positional
                text is given; it may be a single string or a sequence of
                strings.

        Returns:
            Element ``[0]`` of the model output, converted to plain Python
            lists through ``.numpy().tolist()``.
            NOTE(review): presumably the hidden states of the last layer --
            confirm against the model classes registered for this task.
        """
        # Generic compatibility with sklearn and Keras
        if 'X' in kwargs and not texts:
            texts = kwargs.pop('X')

        # A lone string is itself an iterable of characters; wrap it so it is
        # handed to the tokenizer as a batch holding one text.
        if isinstance(texts, str):
            texts = [texts]

        # Decide the framework once so the tensor format requested from the
        # tokenizer and the inference branch below always agree.
        use_tf = is_tf_available()
        inputs = self.tokenizer.batch_encode_plus(
            texts, add_special_tokens=True, return_tensors='tf' if use_tf else 'pt'
        )

        if use_tf:
            # TODO trace model
            predictions = self.model(inputs)[0]
        else:
            # Imported lazily so the module can still be loaded without torch.
            import torch

            with torch.no_grad():
                predictions = self.model(**inputs)[0]

        return predictions.numpy().tolist()
class TextClassificationPipeline(Pipeline):
@@ -424,6 +440,11 @@ class QuestionAnsweringPipeline(Pipeline):
# Register all the supported task here
SUPPORTED_TASKS = {
'feature-extraction': {
'impl': FeatureExtractionPipeline,
'tf': TFAutoModel if is_tf_available() else None,
'pt': AutoModel if is_torch_available() else None,
},
'text-classification': {
'impl': TextClassificationPipeline,
'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,