mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Adds use_auth_token with pipelines (#11123)
* added model_kwargs to infer_framework_from_model * added model_kwargs to tokenizer * added use_auth_token as named parameter * added dynamic get for use_auth_token
This commit is contained in:
parent
1c15128312
commit
3fd7eee18f
@ -246,6 +246,7 @@ def pipeline(
|
||||
framework: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
use_fast: bool = True,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
model_kwargs: Dict[str, Any] = {},
|
||||
**kwargs
|
||||
) -> Pipeline:
|
||||
@ -308,6 +309,10 @@ def pipeline(
|
||||
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
|
||||
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
|
||||
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
model_kwargs:
|
||||
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
|
||||
**model_kwargs)` function.
|
||||
@ -367,6 +372,9 @@ def pipeline(
|
||||
|
||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
|
||||
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
|
||||
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
|
||||
|
||||
# Instantiate tokenizer if needed
|
||||
if isinstance(tokenizer, (str, tuple)):
|
||||
if isinstance(tokenizer, tuple):
|
||||
@ -377,12 +385,12 @@ def pipeline(
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task
|
||||
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
|
||||
)
|
||||
|
||||
# Instantiate config if needed
|
||||
if isinstance(config, str):
|
||||
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task)
|
||||
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
|
||||
|
||||
# Instantiate modelcard if needed
|
||||
if isinstance(modelcard, str):
|
||||
|
@ -48,7 +48,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def infer_framework_from_model(
|
||||
model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None, task: Optional[str] = None
|
||||
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
|
||||
):
|
||||
"""
|
||||
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
|
||||
@ -65,10 +65,11 @@ def infer_framework_from_model(
|
||||
from.
|
||||
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
|
||||
A mapping framework to class.
|
||||
revision (:obj:`str`, `optional`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
task (:obj:`str`):
|
||||
The task defining which pipeline will be returned.
|
||||
model_kwargs:
|
||||
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
|
||||
**model_kwargs)` function.
|
||||
|
||||
Returns:
|
||||
:obj:`Tuple`: A tuple framework, model.
|
||||
@ -80,19 +81,20 @@ def infer_framework_from_model(
|
||||
"To install PyTorch, read the instructions at https://pytorch.org/."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model_kwargs["_from_pipeline"] = task
|
||||
if is_torch_available() and not is_tf_available():
|
||||
model_class = model_classes.get("pt", AutoModel)
|
||||
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
elif is_tf_available() and not is_torch_available():
|
||||
model_class = model_classes.get("tf", TFAutoModel)
|
||||
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
else:
|
||||
try:
|
||||
model_class = model_classes.get("pt", AutoModel)
|
||||
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
except OSError:
|
||||
model_class = model_classes.get("tf", TFAutoModel)
|
||||
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
|
||||
model = model_class.from_pretrained(model, **model_kwargs)
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
return framework, model
|
||||
|
Loading…
Reference in New Issue
Block a user