Allow model conversion in the pipeline allocator.

Morgan Funtowicz 2019-12-13 14:13:14 +01:00
parent 28e64ad5a4
commit 1ca52567a4

@@ -385,5 +385,17 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
     targeted_task = SUPPORTED_TASKS[task]
     task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
-    model = allocator.from_pretrained(model)
+    # Special handling for model conversion
+    from_tf = model.endswith('.h5') and not is_tf_available()
+    from_pt = model.endswith('.bin') and not is_torch_available()
+
+    if from_tf:
+        logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.')
+    elif from_pt:
+        logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
+
+    if allocator.__name__.startswith('TF'):
+        model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
+    else:
+        model = allocator.from_pretrained(model, config=config, from_tf=from_tf)
 
     return task(model, tokenizer, **kwargs)
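
For context, the change only forwards the from_tf / from_pt flags to the allocator's from_pretrained call. The snippet below is a minimal sketch of the same cross-framework loading outside of pipeline(), using hypothetical local checkpoint paths; the config is passed explicitly because a bare weights file carries no config.json, and depending on the library version the actual weight conversion may still require both frameworks to be installed.

# Minimal sketch, not library code: cross-framework loading with from_tf / from_pt.
# 'path/to/tf_model.h5' and 'path/to/pytorch_model.bin' are hypothetical paths.
from transformers import BertConfig, BertForSequenceClassification, TFBertForSequenceClassification

config = BertConfig.from_pretrained('bert-base-uncased')

# TensorFlow 2.0 weights (.h5) loaded into the PyTorch class.
pt_model = BertForSequenceClassification.from_pretrained(
    'path/to/tf_model.h5', config=config, from_tf=True)

# PyTorch weights (.bin) loaded into the TensorFlow class.
tf_model = TFBertForSequenceClassification.from_pretrained(
    'path/to/pytorch_model.bin', config=config, from_pt=True)

With this commit, passing the same kind of local checkpoint path as the model argument of pipeline() takes the branch above automatically: the allocator chosen for the available framework is asked to convert the weights saved with the other one.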