mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Allow model conversion in the pipeline allocator.
This commit is contained in:
parent
28e64ad5a4
commit
1ca52567a4
@ -385,5 +385,17 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
|
||||
|
||||
model = allocator.from_pretrained(model)
|
||||
# Special handling for model conversion
|
||||
from_tf = model.endswith('.h5') and not is_tf_available()
|
||||
from_pt = model.endswith('.bin') and not is_torch_available()
|
||||
|
||||
if from_tf:
|
||||
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.')
|
||||
elif from_pt:
|
||||
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
|
||||
|
||||
if allocator.__name__.startswith('TF'):
|
||||
model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
|
||||
else:
|
||||
model = allocator.from_pretrained(model, config=config, from_tf=from_tf)
|
||||
return task(model, tokenizer, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user