mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Clean up framework handling
This commit is contained in:
parent
ca6bdb28f6
commit
1fa93ca1ea
@ -48,6 +48,19 @@ if is_torch_available():
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_framework(model=None):
|
||||
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
|
||||
# Both framework are available but the use supplied a model class instance.
|
||||
# Try to guess which framework to use from the model classname
|
||||
framework = 'tf' if model.__class__.__name__.startswith('TF') else 'pt'
|
||||
else:
|
||||
framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
|
||||
if framework is None:
|
||||
raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/.")
|
||||
return framework
|
||||
|
||||
class ArgumentHandler(ABC):
|
||||
"""
|
||||
Base interface for handling varargs for each Pipeline
|
||||
@ -279,19 +292,23 @@ class Pipeline(_ScikitCompat):
|
||||
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
|
||||
"""
|
||||
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: ModelCard = None,
|
||||
modelcard: ModelCard = None, framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None, device: int = -1,
|
||||
binary_output: bool = False):
|
||||
|
||||
if framework is None:
|
||||
framework = get_framework()
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
self.device = device
|
||||
self.binary_output = binary_output
|
||||
self._args_parser = args_parser or DefaultArgumentHandler()
|
||||
|
||||
# Special handling
|
||||
if self.device >= 0 and not is_tf_available():
|
||||
if self.device >= 0 and self.framework == 'pt':
|
||||
self.model = self.model.to('cuda:{}'.format(self.device))
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
@ -332,7 +349,7 @@ class Pipeline(_ScikitCompat):
|
||||
Returns:
|
||||
Context manager
|
||||
"""
|
||||
if is_tf_available():
|
||||
if self.framework == 'tf':
|
||||
with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
|
||||
yield
|
||||
else:
|
||||
@ -371,7 +388,7 @@ class Pipeline(_ScikitCompat):
|
||||
with self.device_placement():
|
||||
inputs = self.tokenizer.batch_encode_plus(
|
||||
inputs, add_special_tokens=True,
|
||||
return_tensors='tf' if is_tf_available() else 'pt',
|
||||
return_tensors=self.framework,
|
||||
max_length=self.tokenizer.max_len
|
||||
)
|
||||
|
||||
@ -387,7 +404,7 @@ class Pipeline(_ScikitCompat):
|
||||
Returns:
|
||||
Numpy array
|
||||
"""
|
||||
if is_tf_available():
|
||||
if self.framework == 'tf':
|
||||
# TODO trace model
|
||||
predictions = self.model(inputs, training=False)[0]
|
||||
else:
|
||||
@ -405,9 +422,16 @@ class FeatureExtractionPipeline(Pipeline):
|
||||
def __init__(self, model,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: ModelCard = None,
|
||||
framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: int = -1):
|
||||
super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output=True)
|
||||
super().__init__(model=model,
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
framework=framework,
|
||||
args_parser=args_parser,
|
||||
device=device,
|
||||
binary_output=True)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs).tolist()
|
||||
@ -430,10 +454,16 @@ class NerPipeline(Pipeline):
|
||||
"""
|
||||
|
||||
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
|
||||
modelcard: ModelCard = None,
|
||||
modelcard: ModelCard = None, framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None, device: int = -1,
|
||||
binary_output: bool = False):
|
||||
super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output)
|
||||
super().__init__(model=model,
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
framework=framework,
|
||||
args_parser=args_parser,
|
||||
device=device,
|
||||
binary_output=binary_output)
|
||||
|
||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||
|
||||
@ -452,12 +482,12 @@ class NerPipeline(Pipeline):
|
||||
|
||||
tokens = self.tokenizer.encode_plus(
|
||||
sentence, return_attention_mask=False,
|
||||
return_tensors='tf' if is_tf_available() else 'pt',
|
||||
return_tensors=self.framework,
|
||||
max_length=self.tokenizer.max_len
|
||||
)
|
||||
|
||||
# Forward
|
||||
if is_tf_available():
|
||||
if self.framework == 'tf':
|
||||
entities = self.model(tokens)[0][0].numpy()
|
||||
else:
|
||||
with torch.no_grad():
|
||||
@ -549,6 +579,18 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
Question Answering pipeline using ModelForQuestionAnswering head.
|
||||
"""
|
||||
|
||||
def __init__(self, model,
|
||||
tokenizer: Optional[PreTrainedTokenizer],
|
||||
modelcard: Optional[ModelCard],
|
||||
framework: Optional[str] = None,
|
||||
device: int = -1, **kwargs):
|
||||
super().__init__(model=model,
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
framework=framework,
|
||||
args_parser=QuestionAnsweringArgumentHandler(),
|
||||
device=device, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
|
||||
"""
|
||||
@ -567,12 +609,6 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
else:
|
||||
return SquadExample(None, question, context, None, None, None)
|
||||
|
||||
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer],
|
||||
modelcard: Optional[ModelCard],
|
||||
device: int = -1, **kwargs):
|
||||
super().__init__(model, tokenizer, modelcard, args_parser=QuestionAnsweringArgumentHandler(),
|
||||
device=device, **kwargs)
|
||||
|
||||
def __call__(self, *texts, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
@ -608,7 +644,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
# Manage tensor allocation on correct device
|
||||
with self.device_placement():
|
||||
if is_tf_available():
|
||||
if self.framework == 'tf':
|
||||
fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
|
||||
start, end = self.model(fw_args)
|
||||
start, end = start.numpy(), end.numpy()
|
||||
@ -798,15 +834,10 @@ def pipeline(task: str, model: Optional = None,
|
||||
if task not in SUPPORTED_TASKS:
|
||||
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
|
||||
|
||||
pipeline_framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
|
||||
if pipeline_framework is None:
|
||||
raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
|
||||
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/.")
|
||||
|
||||
framework = get_framework(model)
|
||||
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
task, model_class = targeted_task['impl'], targeted_task[pipeline_framework]
|
||||
task, model_class = targeted_task['impl'], targeted_task[framework]
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
@ -843,14 +874,14 @@ def pipeline(task: str, model: Optional = None,
|
||||
if isinstance(model, str):
|
||||
# Handle transparent TF/PT model conversion
|
||||
model_kwargs = {}
|
||||
if pipeline_framework == 'pt' and model.endswith('.h5'):
|
||||
if framework == 'pt' and model.endswith('.h5'):
|
||||
model_kwargs['from_tf'] = True
|
||||
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
|
||||
'Trying to load the model with PyTorch.')
|
||||
elif pipeline_framework == 'tf' and model.endswith('.bin'):
|
||||
elif framework == 'tf' and model.endswith('.bin'):
|
||||
model_kwargs['from_pt'] = True
|
||||
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
|
||||
'Trying to load the model with Tensorflow.')
|
||||
model = model_class.from_pretrained(model, config=config, **model_kwargs)
|
||||
|
||||
return task(model, tokenizer, **kwargs)
|
||||
return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user