diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 2455f47c09f..fb1b959d468 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -246,6 +246,7 @@ def pipeline( framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, + use_auth_token: Optional[Union[str, bool]] = None, model_kwargs: Dict[str, Any] = {}, **kwargs ) -> Pipeline: @@ -308,6 +309,9 @@ def pipeline( artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`). + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). model_kwargs: Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(..., **model_kwargs)` function. 
@@ -367,6 +372,9 @@ def pipeline( task_class, model_class = targeted_task["impl"], targeted_task[framework] + # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained + model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token) + # Instantiate tokenizer if needed if isinstance(tokenizer, (str, tuple)): if isinstance(tokenizer, tuple): @@ -377,12 +385,12 @@ def pipeline( ) else: tokenizer = AutoTokenizer.from_pretrained( - tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task + tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs ) # Instantiate config if needed if isinstance(config, str): - config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task) + config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs) # Instantiate modelcard if needed if isinstance(modelcard, str): diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 9da13796f58..d06376aa43c 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -48,7 +48,7 @@ logger = logging.get_logger(__name__) def infer_framework_from_model( - model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None, task: Optional[str] = None + model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs ): """ Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model). @@ -65,10 +65,11 @@ def infer_framework_from_model( from. model_classes (dictionary :obj:`str` to :obj:`type`, `optional`): A mapping framework to class. - revision (:obj:`str`, `optional`): - The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any - identifier allowed by git. + task (:obj:`str`): + The task defining which pipeline will be returned. + model_kwargs: + Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(..., + **model_kwargs)` function. Returns: :obj:`Tuple`: A tuple framework, model. @@ -80,19 +81,20 @@ def infer_framework_from_model( "To install PyTorch, read the instructions at https://pytorch.org/." ) if isinstance(model, str): + model_kwargs["_from_pipeline"] = task if is_torch_available() and not is_tf_available(): model_class = model_classes.get("pt", AutoModel) - model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task) + model = model_class.from_pretrained(model, **model_kwargs) elif is_tf_available() and not is_torch_available(): model_class = model_classes.get("tf", TFAutoModel) - model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task) + model = model_class.from_pretrained(model, **model_kwargs) else: try: model_class = model_classes.get("pt", AutoModel) - model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task) + model = model_class.from_pretrained(model, **model_kwargs) except OSError: model_class = model_classes.get("tf", TFAutoModel) - model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task) + model = model_class.from_pretrained(model, **model_kwargs) framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" return framework, model