From 1fa93ca1eaa249321ef39994e9f022d0799034a3 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Fri, 20 Dec 2019 12:34:19 +0100
Subject: [PATCH] Clean up framework handling

---
 transformers/pipelines.py | 85 ++++++++++++++++++++++++++-------------
 1 file changed, 58 insertions(+), 27 deletions(-)

diff --git a/transformers/pipelines.py b/transformers/pipelines.py
index be2b1db1265..1c56033f7ca 100755
--- a/transformers/pipelines.py
+++ b/transformers/pipelines.py
@@ -48,6 +48,19 @@ if is_torch_available():
 
 logger = logging.getLogger(__name__)
 
+def get_framework(model=None):
+    if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
+        # Both frameworks are available, but the user supplied a model class instance.
+        # Try to guess which framework to use from the model class name.
+        framework = 'tf' if model.__class__.__name__.startswith('TF') else 'pt'
+    else:
+        framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
+    if framework is None:
+        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
+                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
+                          "To install PyTorch, read the instructions at https://pytorch.org/.")
+    return framework
+
 class ArgumentHandler(ABC):
     """
     Base interface for handling varargs for each Pipeline
@@ -279,19 +292,23 @@ class Pipeline(_ScikitCompat):
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
-                 modelcard: ModelCard = None,
+                 modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
                  binary_output: bool = False):
+
+        if framework is None:
+            framework = get_framework()
+
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
+        self.framework = framework
         self.device = device
         self.binary_output = binary_output
         self._args_parser = args_parser or DefaultArgumentHandler()
 
         # Special handling
-        if self.device >= 0 and not is_tf_available():
+        if self.device >= 0 and self.framework == 'pt':
             self.model = self.model.to('cuda:{}'.format(self.device))
 
     def save_pretrained(self, save_directory):
@@ -332,7 +349,7 @@ class Pipeline(_ScikitCompat):
         Returns:
             Context manager
         """
-        if is_tf_available():
+        if self.framework == 'tf':
             with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
                 yield
         else:
@@ -371,7 +388,7 @@ class Pipeline(_ScikitCompat):
         with self.device_placement():
             inputs = self.tokenizer.batch_encode_plus(
                 inputs, add_special_tokens=True,
-                return_tensors='tf' if is_tf_available() else 'pt',
+                return_tensors=self.framework,
                 max_length=self.tokenizer.max_len
             )
 
@@ -387,7 +404,7 @@ class Pipeline(_ScikitCompat):
         Returns:
             Numpy array
         """
-        if is_tf_available():
+        if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(inputs, training=False)[0]
         else:
@@ -405,9 +422,16 @@ class FeatureExtractionPipeline(Pipeline):
 
     def __init__(self, model,
                  tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None,
+                 framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None,
                  device: int = -1):
-        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output=True)
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=args_parser,
+                         device=device,
+                         binary_output=True)
 
     def __call__(self, *args, **kwargs):
         return super().__call__(*args, **kwargs).tolist()
@@ -430,10 +454,16 @@ class NerPipeline(Pipeline):
     """
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
-                 modelcard: ModelCard = None,
+                 modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None,
                  device: int = -1,
                  binary_output: bool = False):
-        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output)
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=args_parser,
+                         device=device,
+                         binary_output=binary_output)
 
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
 
@@ -452,12 +482,12 @@ class NerPipeline(Pipeline):
                 tokens = self.tokenizer.encode_plus(
                     sentence, return_attention_mask=False,
-                    return_tensors='tf' if is_tf_available() else 'pt',
+                    return_tensors=self.framework,
                     max_length=self.tokenizer.max_len
                 )
 
                 # Forward
-                if is_tf_available():
+                if self.framework == 'tf':
                     entities = self.model(tokens)[0][0].numpy()
                 else:
                     with torch.no_grad():
@@ -549,6 +579,18 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline using ModelForQuestionAnswering head.
     """
 
+    def __init__(self, model,
+                 tokenizer: Optional[PreTrainedTokenizer],
+                 modelcard: Optional[ModelCard],
+                 framework: Optional[str] = None,
+                 device: int = -1, **kwargs):
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=QuestionAnsweringArgumentHandler(),
+                         device=device, **kwargs)
+
     @staticmethod
     def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
         """
@@ -567,12 +609,6 @@ class QuestionAnsweringPipeline(Pipeline):
         else:
             return SquadExample(None, question, context, None, None, None)
 
-    def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer],
-                 modelcard: Optional[ModelCard],
-                 device: int = -1, **kwargs):
-        super().__init__(model, tokenizer, modelcard, args_parser=QuestionAnsweringArgumentHandler(),
-                         device=device, **kwargs)
-
     def __call__(self, *texts, **kwargs):
         """
         Args:
@@ -608,7 +644,7 @@ class QuestionAnsweringPipeline(Pipeline):
 
             # Manage tensor allocation on correct device
             with self.device_placement():
-                if is_tf_available():
+                if self.framework == 'tf':
                     fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                     start, end = self.model(fw_args)
                     start, end = start.numpy(), end.numpy()
@@ -798,15 +834,10 @@ def pipeline(task: str, model: Optional = None,
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
 
-    pipeline_framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
-    if pipeline_framework is None:
-        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
-                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
-                          "To install PyTorch, read the instructions at https://pytorch.org/.")
-
+    framework = get_framework(model)
     targeted_task = SUPPORTED_TASKS[task]
-    task, model_class = targeted_task['impl'], targeted_task[pipeline_framework]
+    task, model_class = targeted_task['impl'], targeted_task[framework]
 
     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
@@ -843,14 +874,14 @@ def pipeline(task: str, model: Optional = None,
     if isinstance(model, str):
         # Handle transparent TF/PT model conversion
         model_kwargs = {}
-        if pipeline_framework == 'pt' and model.endswith('.h5'):
+        if framework == 'pt' and model.endswith('.h5'):
             model_kwargs['from_tf'] = True
             logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
                            'Trying to load the model with PyTorch.')
-        elif pipeline_framework == 'tf' and model.endswith('.bin'):
+        elif framework == 'tf' and model.endswith('.bin'):
             model_kwargs['from_pt'] = True
             logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
                            'Trying to load the model with Tensorflow.')
 
         model = model_class.from_pretrained(model, config=config, **model_kwargs)
 
-    return task(model, tokenizer, **kwargs)
+    return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)