mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Added FeatureExtraction pipeline.
This commit is contained in:
parent
f1971bf303
commit
8e3b1c860f
@ -143,7 +143,23 @@ class JsonPipelineDataFormat(PipelineDataFormat):
|
||||
class FeatureExtractionPipeline(Pipeline):
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
pass
|
||||
# Generic compatibility with sklearn and Keras
|
||||
if 'X' in kwargs and not texts:
|
||||
texts = kwargs.pop('X')
|
||||
|
||||
inputs = self.tokenizer.batch_encode_plus(
|
||||
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
# TODO trace model
|
||||
predictions = self.model(inputs)[0]
|
||||
else:
|
||||
import torch
|
||||
with torch.no_grad():
|
||||
predictions = self.model(**inputs)[0]
|
||||
|
||||
return predictions.numpy().tolist()
|
||||
|
||||
|
||||
class TextClassificationPipeline(Pipeline):
|
||||
@ -424,6 +440,11 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
# Register all the supported task here
|
||||
SUPPORTED_TASKS = {
|
||||
'feature-extraction': {
|
||||
'impl': FeatureExtractionPipeline,
|
||||
'tf': TFAutoModel if is_tf_available() else None,
|
||||
'pt': AutoModel if is_torch_available() else None,
|
||||
},
|
||||
'text-classification': {
|
||||
'impl': TextClassificationPipeline,
|
||||
'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
|
||||
|
Loading…
Reference in New Issue
Block a user