Clean up framework handling

thomwolf 2019-12-20 12:34:19 +01:00
parent ca6bdb28f6
commit 1fa93ca1ea


@@ -48,6 +48,19 @@ if is_torch_available():
 logger = logging.getLogger(__name__)
 
+
+def get_framework(model=None):
+    if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
+        # Both frameworks are available but the user supplied a model class instance.
+        # Try to guess which framework to use from the model class name.
+        framework = 'tf' if model.__class__.__name__.startswith('TF') else 'pt'
+    else:
+        framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
+
+    if framework is None:
+        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
+                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
+                          "To install PyTorch, read the instructions at https://pytorch.org/.")
+    return framework
+
 
 class ArgumentHandler(ABC):
     """
     Base interface for handling varargs for each Pipeline
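For orientation, a hedged sketch of how the new helper resolves the backend, assuming an environment with both TensorFlow 2.0 and PyTorch installed (the BERT checkpoint is only an example):

from transformers import BertModel, TFBertModel
from transformers.pipelines import get_framework

get_framework()                   # no model: 'tf' wins when both backends are present
get_framework('bert-base-cased')  # string identifiers use the same availability fallback
get_framework(TFBertModel.from_pretrained('bert-base-cased'))  # 'tf': class name starts with 'TF'
get_framework(BertModel.from_pretrained('bert-base-cased'))    # 'pt': any other class name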
@@ -279,19 +292,23 @@ class Pipeline(_ScikitCompat):
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
 
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
-                 modelcard: ModelCard = None,
+                 modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
                  binary_output: bool = False):
+
+        if framework is None:
+            framework = get_framework()
+
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
+        self.framework = framework
         self.device = device
         self.binary_output = binary_output
         self._args_parser = args_parser or DefaultArgumentHandler()
 
         # Special handling
-        if self.device >= 0 and not is_tf_available():
+        if self.device >= 0 and self.framework == 'pt':
             self.model = self.model.to('cuda:{}'.format(self.device))
 
     def save_pretrained(self, save_directory):
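One practical consequence worth noting: Pipeline.__init__ calls get_framework() without the model, so the fallback picks 'tf' whenever TensorFlow is importable, and only framework == 'pt' triggers the CUDA move. A sketch under the assumption of a mixed TF + PyTorch environment (class and checkpoint names illustrative):

from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers.pipelines import NerPipeline

model = AutoModelForTokenClassification.from_pretrained('bert-base-cased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

# Passing framework='pt' explicitly keeps a PyTorch model on the PyTorch code
# path even when TensorFlow is also installed; device=0 then moves it to cuda:0.
nlp = NerPipeline(model=model, tokenizer=tokenizer, framework='pt', device=0)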
@@ -332,7 +349,7 @@ class Pipeline(_ScikitCompat):
         Returns:
             Context manager
         """
-        if is_tf_available():
+        if self.framework == 'tf':
             with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
                 yield
         else:
@@ -371,7 +388,7 @@ class Pipeline(_ScikitCompat):
         with self.device_placement():
             inputs = self.tokenizer.batch_encode_plus(
                 inputs, add_special_tokens=True,
-                return_tensors='tf' if is_tf_available() else 'pt',
+                return_tensors=self.framework,
                 max_length=self.tokenizer.max_len
             )
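Since self.framework is already the string 'tf' or 'pt', it can feed return_tensors directly. A minimal sketch of the effect (tokenizer checkpoint illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
batch = tokenizer.batch_encode_plus(['Hello world'], add_special_tokens=True,
                                    return_tensors='pt',
                                    max_length=tokenizer.max_len)
print(type(batch['input_ids']))  # torch.Tensor; with return_tensors='tf' it would be a tf.Tensor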
@@ -387,7 +404,7 @@ class Pipeline(_ScikitCompat):
         Returns:
             Numpy array
         """
-        if is_tf_available():
+        if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(inputs, training=False)[0]
         else:
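The hunk truncates at the PyTorch branch; going by the pattern used elsewhere in this file, the counterpart presumably looks something like this sketch:

        else:
            with torch.no_grad():
                # Disable autograd for inference, move to CPU before converting.
                predictions = self.model(**inputs)[0].cpu()

        return predictions.numpy()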
@@ -405,9 +422,16 @@ class FeatureExtractionPipeline(Pipeline):
 
     def __init__(self, model,
                  tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None,
+                 framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None,
                  device: int = -1):
-        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output=True)
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=args_parser,
+                         device=device,
+                         binary_output=True)
 
     def __call__(self, *args, **kwargs):
         return super().__call__(*args, **kwargs).tolist()
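A brief usage sketch (checkpoint illustrative): the .tolist() override means callers get plain nested Python lists rather than framework tensors:

from transformers import AutoModel, AutoTokenizer
from transformers.pipelines import FeatureExtractionPipeline

extractor = FeatureExtractionPipeline(
    model=AutoModel.from_pretrained('distilbert-base-uncased'),
    tokenizer=AutoTokenizer.from_pretrained('distilbert-base-uncased'),
    framework='pt')
features = extractor('Hello world')  # List[List[List[float]]]: batch x tokens x hidden size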
@@ -430,10 +454,16 @@ class NerPipeline(Pipeline):
     """
 
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
-                 modelcard: ModelCard = None,
+                 modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
                  binary_output: bool = False):
-        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output)
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=args_parser,
+                         device=device,
+                         binary_output=binary_output)
 
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
@@ -452,12 +482,12 @@ class NerPipeline(Pipeline):
                 tokens = self.tokenizer.encode_plus(
                     sentence, return_attention_mask=False,
-                    return_tensors='tf' if is_tf_available() else 'pt',
+                    return_tensors=self.framework,
                     max_length=self.tokenizer.max_len
                 )
 
                 # Forward
-                if is_tf_available():
+                if self.framework == 'tf':
                     entities = self.model(tokens)[0][0].numpy()
                 else:
                     with torch.no_grad():
@@ -549,6 +579,18 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline using ModelForQuestionAnswering head.
     """
 
+    def __init__(self, model,
+                 tokenizer: Optional[PreTrainedTokenizer],
+                 modelcard: Optional[ModelCard],
+                 framework: Optional[str] = None,
+                 device: int = -1, **kwargs):
+        super().__init__(model=model,
+                         tokenizer=tokenizer,
+                         modelcard=modelcard,
+                         framework=framework,
+                         args_parser=QuestionAnsweringArgumentHandler(),
+                         device=device, **kwargs)
+
     @staticmethod
     def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
         """
@@ -567,12 +609,6 @@ class QuestionAnsweringPipeline(Pipeline):
         else:
             return SquadExample(None, question, context, None, None, None)
 
-    def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer],
-                 modelcard: Optional[ModelCard],
-                 device: int = -1, **kwargs):
-        super().__init__(model, tokenizer, modelcard, args_parser=QuestionAnsweringArgumentHandler(),
-                         device=device, **kwargs)
-
     def __call__(self, *texts, **kwargs):
         """
         Args:
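With the constructor moved above create_sample, instantiation and sample construction now read in source order. A short usage sketch (checkpoint illustrative):

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers.pipelines import QuestionAnsweringPipeline

qa = QuestionAnsweringPipeline(
    model=AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad'),
    tokenizer=AutoTokenizer.from_pretrained('distilbert-base-cased-distilled-squad'),
    modelcard=None, framework='pt')
sample = qa.create_sample(question='Where do I live?',
                          context='I live in Paris.')  # a single SquadExample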
@@ -608,7 +644,7 @@ class QuestionAnsweringPipeline(Pipeline):
 
         # Manage tensor allocation on correct device
         with self.device_placement():
-            if is_tf_available():
+            if self.framework == 'tf':
                 fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
                 start, end = self.model(fw_args)
                 start, end = start.numpy(), end.numpy()
@@ -798,15 +834,10 @@ def pipeline(task: str, model: Optional = None,
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
 
-    pipeline_framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
-    if pipeline_framework is None:
-        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
-                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
-                          "To install PyTorch, read the instructions at https://pytorch.org/.")
+    framework = get_framework(model)
 
     targeted_task = SUPPORTED_TASKS[task]
-    task, model_class = targeted_task['impl'], targeted_task[pipeline_framework]
+    task, model_class = targeted_task['impl'], targeted_task[framework]
 
     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
@@ -843,14 +874,14 @@ def pipeline(task: str, model: Optional = None,
     if isinstance(model, str):
         # Handle transparent TF/PT model conversion
         model_kwargs = {}
-        if pipeline_framework == 'pt' and model.endswith('.h5'):
+        if framework == 'pt' and model.endswith('.h5'):
             model_kwargs['from_tf'] = True
             logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
                            'Trying to load the model with PyTorch.')
-        elif pipeline_framework == 'tf' and model.endswith('.bin'):
+        elif framework == 'tf' and model.endswith('.bin'):
             model_kwargs['from_pt'] = True
             logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
                            'Trying to load the model with Tensorflow.')
 
         model = model_class.from_pretrained(model, config=config, **model_kwargs)
 
-    return task(model, tokenizer, **kwargs)
+    return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
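End to end, the factory now infers the framework once via get_framework(model) and threads it through to the task class. A hedged sketch of both entry points, assuming both backends are installed (checkpoints illustrative):

from transformers import pipeline, TFAutoModelForQuestionAnswering

# String identifier: get_framework(model) falls back on availability,
# so 'tf' is chosen here because TensorFlow is importable.
nlp = pipeline('question-answering', model='distilbert-base-cased-distilled-squad')

# Model instance: the 'TF' class-name prefix routes the pipeline to TensorFlow
# and selects the TF model class from SUPPORTED_TASKS.
tf_model = TFAutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad')
nlp = pipeline('question-answering', model=tf_model,
               tokenizer='distilbert-base-cased-distilled-squad')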