diff --git a/transformers/pipelines.py b/transformers/pipelines.py index 853735a2567..9acd9bc5664 100755 --- a/transformers/pipelines.py +++ b/transformers/pipelines.py @@ -385,5 +385,17 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni targeted_task = SUPPORTED_TASKS[task] task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt'] - model = allocator.from_pretrained(model) + # Special handling for model conversion + from_tf = model.endswith('.h5') and not is_tf_available() + from_pt = model.endswith('.bin') and not is_torch_available() + + if from_tf: + logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.') + elif from_pt: + logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.') + + if allocator.__name__.startswith('TF'): + model = allocator.from_pretrained(model, config=config, from_pt=from_pt) + else: + model = allocator.from_pretrained(model, config=config, from_tf=from_tf) return task(model, tokenizer, **kwargs)