Initial bunch of documentation.

This commit is contained in:
Morgan Funtowicz 2019-12-17 12:16:07 +01:00
parent d7c62661a3
commit 2fde5a2489


@@ -80,6 +80,15 @@ class _ScikitCompat(ABC):
class PipelineDataFormat:
"""
Base class for all the pipeline-supported data formats, both for reading and writing.
Supported data formats currently include:
- JSON
- CSV
PipelineDataFormat also includes some utilities to work with multi-column data, like mapping from dataset columns
to pipeline keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
"""
SUPPORTED_FORMATS = ['json', 'csv']
def __init__(self, output: str, path: str, column: str):
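# Illustrative sketch (not part of this commit): how a multi-column mapping string in the
# `dataset_kwarg_1=dataset_column_1` format described above could be split into pipeline
# keyword arguments. The column names are hypothetical.
column = "question=question_field,context=context_field"
mapping = dict(item.split("=") for item in column.split(","))
# mapping == {"question": "question_field", "context": "context_field"}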
@@ -138,7 +147,6 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output: str, path: str, column: str):
super().__init__(output, path, column)
@@ -158,6 +166,11 @@ class JsonPipelineDataFormat(PipelineDataFormat):
class Pipeline(_ScikitCompat):
"""
Base class implementing pipelined operations.
The pipeline workflow is defined as a sequence of the following operations:
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
"""
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
args_parser: ArgumentHandler = None, device: int = -1, **kwargs):
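# Minimal usage sketch of the workflow above, assuming `pipeline` is exported by the
# transformers package and "sentiment-analysis" is a registered task; the checkpoint
# name is only an example.
from transformers import pipeline

nlp = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
print(nlp("Transformers pipelines make inference straightforward."))
# Input -> Tokenization -> Model Inference -> Post-Processing -> Output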
@@ -171,6 +184,9 @@ class Pipeline(_ScikitCompat):
self.model = self.model.to('cuda:{}'.format(self.device))
def save_pretrained(self, save_directory):
"""
Save the pipeline's model and tokenizer to the specified save_directory
"""
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
return
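# Hedged usage sketch (continuing the example above): the target directory must already
# exist, otherwise save_pretrained() logs an error and returns. The path is hypothetical.
import os

save_dir = "./my_pipeline"
os.makedirs(save_dir, exist_ok=True)
nlp.save_pretrained(save_dir)  # writes both the model and the tokenizer files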
@@ -179,9 +195,16 @@ class Pipeline(_ScikitCompat):
self.tokenizer.save_pretrained(save_directory)
def transform(self, X):
"""
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
"""
return self(X=X)
def predict(self, X):
"""
Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
"""
return self(X=X)
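# Sketch of the scikit-learn / Keras style entry points (continuing the example above):
# transform() and predict() simply forward to __call__(), so they behave like calling
# the pipeline directly.
texts = ["I love this.", "I hate this."]
predictions = nlp.predict(texts)   # same as nlp(X=texts)
outputs = nlp.transform(texts)     # same as nlp(X=texts)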
def __call__(self, *texts, **kwargs):
@@ -198,6 +221,17 @@ class Pipeline(_ScikitCompat):
@contextmanager
def device_placement(self):
"""
Context manager allowing tensor allocation on the user-specified device in a framework-agnostic way.
example:
# Explicitly ask for tensor allocation on CUDA device :0
nlp = pipeline(..., device=0)
with nlp.device_placement():
# Every framework-specific tensor allocation will be done on the requested device
output = nlp(...)
Returns:
Context manager
"""
if is_tf_available():
import tensorflow as tf
with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
@@ -210,6 +244,13 @@ class Pipeline(_ScikitCompat):
yield
def _forward(self, inputs):
"""
Internal framework specific forward dispatching.
Args:
inputs: dict holding all the keyword arguments required by the model's forward method.
Returns:
Numpy array
"""
if is_tf_available():
# TODO trace model
predictions = self.model(inputs)[0]
@@ -222,11 +263,17 @@ class Pipeline(_ScikitCompat):
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using the base model's output (no task-specific head).
"""
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()
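# Hedged sketch, assuming "feature-extraction" is a registered task; the checkpoint name
# is only an example. The output is the raw model output converted with .tolist(), i.e.
# nested Python lists of floats (one vector per token).
from transformers import pipeline

extractor = pipeline("feature-extraction", model="distilbert-base-uncased")
features = extractor("Hello world")
# features[0][0] is the first token's vector; its length is the model's hidden size.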
class TextClassificationPipeline(Pipeline):
"""
Text classification pipeline using ModelForTextClassification head.
"""
def __init__(self, model, tokenizer: PreTrainedTokenizer, nb_classes: int = 2):
super().__init__(model, tokenizer)
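# Hedged sketch of direct construction, assuming TextClassificationPipeline and the Auto
# classes are importable from transformers; the checkpoint name is an example of a model
# fine-tuned for binary sentiment classification.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
clf = TextClassificationPipeline(
    AutoModelForSequenceClassification.from_pretrained(checkpoint),
    AutoTokenizer.from_pretrained(checkpoint),
    nb_classes=2,
)
print(clf("This movie was great!"))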
@@ -239,7 +286,9 @@ class TextClassificationPipeline(Pipeline):
class NerPipeline(Pipeline):
"""
Named Entity Recognition pipeline using ModelForTokenClassification head.
"""
def __init__(self, model, tokenizer: PreTrainedTokenizer):
super().__init__(model, tokenizer)
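# Hedged sketch, assuming "ner" is a registered task; a checkpoint fine-tuned for token
# classification should be used, the name below is only an example.
from transformers import pipeline

ner = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")
print(ner("Hugging Face is based in New York City"))
# Typically a list of dicts holding the recognized entity, its score and the matching word.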
@@ -286,7 +335,7 @@ class NerPipeline(Pipeline):
class QuestionAnsweringPipeline(Pipeline):
"""
Question Answering pipeline involving Tokenization and Inference.
Question Answering pipeline using ModelForQuestionAnswering head.
"""
class QuestionAnsweringArgumentHandler(ArgumentHandler):
@@ -341,9 +390,15 @@ class QuestionAnsweringPipeline(Pipeline):
@staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
is_list = isinstance(question, list)
if is_list:
"""
QuestionAnsweringPipeline leverages SquadExample/SquadFeatures internally.
This helper method encapsulates all the logic for converting question(s) and context(s) to SquadExample(s).
We currently support extractive question answering.
Args:
question: (str, List[str]) The question(s) to ask of the associated context(s)
context: (str, List[str]) The context(s) in which to look for the answer.
"""
if isinstance(question, list):
return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
else:
return SquadExample(None, question, context, None, None, None)
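# Hedged sketch of the helper above, assuming QuestionAnsweringPipeline is importable from
# transformers: one question/context pair yields a single SquadExample, lists yield a list.
from transformers import QuestionAnsweringPipeline

sample = QuestionAnsweringPipeline.create_sample(
    question="Where does Morgan work?",
    context="Morgan works at Hugging Face in Paris.",
)
samples = QuestionAnsweringPipeline.create_sample(
    question=["Where does Morgan work?", "Where is Hugging Face based?"],
    context=["Morgan works at Hugging Face.", "Hugging Face is based in New York City."],
)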
@@ -352,6 +407,12 @@ class QuestionAnsweringPipeline(Pipeline):
super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
"""
Generates the input dictionary with model-specific parameters.
Returns:
dict holding all the required parameters for the model's forward method
"""
args = ['input_ids', 'attention_mask']
model_type = type(self.model).__name__.lower()
@@ -367,6 +428,20 @@ class QuestionAnsweringPipeline(Pipeline):
return {k: [feature.__dict__[k] for feature in features] for k in args}
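# Hedged illustration of the resulting structure (token ids truncated/hypothetical): for
# DistilBERT-like models only input_ids and attention_mask are forwarded, while other
# model types may additionally receive further arguments such as token_type_ids.
example_inputs = {
    "input_ids": [[101, 2073, 2515, 5332, 2147, 102]],
    "attention_mask": [[1, 1, 1, 1, 1, 1]],
}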
def __call__(self, *texts, **kwargs):
"""
Args:
We support multiple use-cases; the following are exclusive:
X: sequence of SquadExample
data: sequence of SquadExample
question: (str, List[str]), batch of question(s) to map along with the context
context: (str, List[str]), batch of context(s) associated with the provided question keyword argument
Returns:
dict: {'answer': str, 'score': float, 'start': int, 'end': int}
answer: the textual answer in the initial context
score: the score the model assigned to the current answer
start: the character index in the original string corresponding to the beginning of the answer's span
end: the character index in the original string corresponding to the end of the answer's span
"""
# Set defaults values
kwargs.setdefault('topk', 1)
kwargs.setdefault('doc_stride', 128)
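# Hedged end-to-end sketch, assuming "question-answering" is a registered task; the
# checkpoint name is an example of a SQuAD fine-tuned model and the returned values are
# illustrative.
from transformers import pipeline

qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
result = qa(question="Where does Morgan work?", context="Morgan works at Hugging Face in Paris.")
# result is a dict such as {'answer': 'Hugging Face', 'score': 0.97, 'start': 16, 'end': 28}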
@@ -432,6 +507,19 @@ class QuestionAnsweringPipeline(Pipeline):
return answers
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
"""
Takes the output of any QuestionAnswering head and generates probabilities for each span to be
the actual answer.
In addition, it filters out some unwanted/impossible cases, like the answer length being greater than
max_answer_len or the answer end position being before the starting position.
The method supports outputting the k-best answers through the topk argument.
Args:
start: numpy array, holding individual start probabilities for each token
end: numpy array, holding individual end probabilities for each token
topk: int, indicates how many possible answer span(s) to extract from the model's output
max_answer_len: int, maximum size of the answer to extract from the model's output
"""
# Ensure we have batch axis
if start.ndim == 1:
start = start[None]
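# Hedged sketch of the span-scoring idea described above (not necessarily the exact
# implementation): score every (start, end) pair as start[i] * end[j], then mask spans that
# end before they start or are longer than max_answer_len.
import numpy as np

start_p = np.array([[0.1, 0.7, 0.1, 0.1]])  # per-token start probabilities, batch of 1
end_p = np.array([[0.1, 0.1, 0.7, 0.1]])    # per-token end probabilities
max_answer_len = 2
outer = np.matmul(np.expand_dims(start_p, -1), np.expand_dims(end_p, 1))
candidates = np.tril(np.triu(outer), max_answer_len - 1)  # keep only valid spans
best = np.unravel_index(candidates.argmax(), candidates.shape)
# best[1] and best[2] are the selected start and end token indices (1 and 2 here)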
@@ -459,6 +547,18 @@ class QuestionAnsweringPipeline(Pipeline):
return start, end, candidates[0, start, end]
def span_to_answer(self, text: str, start: int, end: int):
"""
When decoding from token probabilities, this method maps token indexes to actual words in
the initial context.
Args:
text: str, the actual context to extract the answer from
start: int, starting answer token index
end: int, ending answer token index
Returns:
dict: {'answer': str, 'start': int, 'end': int}
"""
words = []
token_idx = char_start_idx = char_end_idx = chars_idx = 0
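# Hedged usage sketch: given the token indices selected by decode(), span_to_answer() maps
# them back to words and character offsets in the original context. The indices and the
# returned values below are hypothetical (qa continues the question-answering example above).
answer = qa.span_to_answer("Morgan works at Hugging Face in Paris.", start=3, end=4)
# e.g. {'answer': 'Hugging Face', 'start': 16, 'end': 28}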
@@ -514,7 +614,11 @@ SUPPORTED_TASKS = {
def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
"""
Utility factory method to build pipeline.
Utility factory method to build a pipeline.
Pipelines are made of:
A Tokenizer instance in charge of mapping raw textual input to tokens
A Model instance
Some (optional) post-processing for enhancing the model's output
"""
# Try to infer tokenizer from model name (if provided as str)
if tokenizer is None:
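# Hedged sketch of the factory described above (checkpoint names are examples): when
# tokenizer is None, the factory tries to infer it from the model name, as the code above
# indicates; it can also be passed explicitly.
from transformers import pipeline

nlp = pipeline("question-answering", "distilbert-base-uncased-distilled-squad")
nlp = pipeline(
    "question-answering",
    "distilbert-base-uncased-distilled-squad",
    tokenizer="distilbert-base-uncased-distilled-squad",
)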