[pipeline] A simple fix for half-precision & 8bit models (#21479)

* v1 fix

* adapt from suggestions

* make style

* fix tests

* add gpu tests

* update docs

* fix other tests

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* better fix

* make fixup

* better example

* revert changes

* proposal

* more elegant solution

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Younes Belkada 2023-02-10 10:26:17 +01:00 committed by GitHub
parent 97d3390fc8
commit f83942684d
6 changed files with 63 additions and 12 deletions


@@ -105,6 +105,8 @@ If the model is too large for a single GPU, you can set `device_map="auto"` to a
generator(model="openai/whisper-large", device_map="auto")
```
Note that if `device_map="auto"` is passed, there is no need to add the argument `device=device` when instantiating your `pipeline`, as you may otherwise encounter unexpected behavior!
### Batch size
By default, pipelines will not batch inference for reasons explained in detail [here](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching). The reason is that batching is not necessarily faster, and can actually be quite a bit slower in some cases.
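If you do want batching, it is opt-in via `batch_size`; a minimal sketch, reusing the `facebook/opt-1.3b` checkpoint from the example below (the prompts and `batch_size` value are illustrative):

```py
from transformers import pipeline

pipe = pipeline(model="facebook/opt-1.3b", device_map="auto")
# Batching is opt-in via `batch_size`; measure before assuming it is faster.
outputs = pipe(["First prompt", "Second prompt"], batch_size=2)
```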
@@ -257,4 +259,32 @@ sudo apt install -y tesseract-ocr
pip install pytesseract
```
</Tip>
## Using `pipeline` on large models with 🤗 `accelerate`
You can easily run `pipeline` on large models using 🤗 `accelerate`! First make sure you have installed `accelerate` with `pip install accelerate`.
Then load your model using `device_map="auto"`! We will use `facebook/opt-1.3b` for our example.
```py
# pip install accelerate
import torch
from transformers import pipeline
pipe = pipeline(model="facebook/opt-1.3b", torch_dtype=torch.bfloat16, device_map="auto")
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
```
You can also load models in 8-bit if you install `bitsandbytes` and pass `load_in_8bit=True` through `model_kwargs`:
```py
# pip install accelerate bitsandbytes
import torch
from transformers import pipeline
pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", model_kwargs={"load_in_8bit": True})
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
```
Note that you can replace the checkpoint with any Hugging Face model that supports large model loading, such as BLOOM!


@@ -738,6 +738,11 @@ def pipeline(
                'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
                " arguments might conflict, use only one.)"
            )
        if device is not None:
            logger.warning(
                "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
                " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
            )
        model_kwargs["device_map"] = device_map
    if torch_dtype is not None:
        if "torch_dtype" in model_kwargs:


@@ -286,9 +286,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            installed. If no framework is specified, will default to the one currently installed. If no framework is
            specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
            no model is provided.
        device (`int`, *optional*, defaults to -1):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
            the associated CUDA device id.
        device (Union[`int`, `torch.device`], *optional*):
            Device ordinal for CPU/GPU support. Setting this to `None` will leverage CPU; a positive integer will run
            the model on the associated CUDA device id.
        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
            [PyCTCDecode's
            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
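As an illustration of the new default, an automatic-speech-recognition pipeline can now skip `device` entirely and defer placement to `accelerate`; a sketch reusing the `openai/whisper-large` checkpoint from the docs above (the audio path is hypothetical):

```py
import torch
from transformers import pipeline

# `device` is left as `None`; with `device_map="auto"` the device is taken from `hf_device_map`.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large",
    torch_dtype=torch.float16,
    device_map="auto",
)
result = asr("path/to/audio.flac")  # hypothetical audio file
```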


@@ -749,7 +749,7 @@ class Pipeline(_ScikitCompat):
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: Union[int, str, "torch.device"] = -1,
        device: Union[int, str, "torch.device"] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        binary_output: bool = False,
        **kwargs,
@@ -764,6 +764,19 @@
        self.image_processor = image_processor
        self.modelcard = modelcard
        self.framework = framework

        if self.framework == "pt" and device is not None:
            self.model = self.model.to(device=device)

        if device is None:
            # `accelerate` device map
            hf_device_map = getattr(self.model, "hf_device_map", None)
            if hf_device_map is not None:
                # Take the first device used by `accelerate`.
                device = next(iter(hf_device_map.values()))
            else:
                device = -1

        if is_torch_available() and self.framework == "pt":
            if isinstance(device, torch.device):
                self.device = device
@@ -774,14 +787,10 @@
            else:
                self.device = torch.device(f"cuda:{device}")
        else:
            self.device = device
            self.device = device if device is not None else -1
        self.torch_dtype = torch_dtype
        self.binary_output = binary_output

        # Special handling
        if self.framework == "pt" and self.device.type != "cpu":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
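The resolution order introduced above can be summarized in a small standalone sketch (the helper name is illustrative, not actual source):

```py
# Illustrative summary of the new device resolution in `Pipeline.__init__` (not the actual source):
# 1. An explicit `device` argument wins and the model is moved to it.
# 2. Otherwise, if `accelerate` populated `hf_device_map`, reuse the first device from that map.
# 3. Otherwise fall back to CPU (-1).
def resolve_pipeline_device(model, device=None):
    if device is not None:
        return device
    hf_device_map = getattr(model, "hf_device_map", None)
    if hf_device_map is not None:
        return next(iter(hf_device_map.values()))
    return -1
```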


@@ -255,7 +255,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        device: int = -1,
        task: str = "",
        **kwargs,
    ):
@@ -264,7 +263,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            device=device,
            task=task,
            **kwargs,
        )


@@ -312,3 +312,12 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
        pipe("This is a test")

    @require_torch
    @require_accelerate
    @require_torch_gpu
    def test_pipeline_accelerate_top_p(self):
        import torch

        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
        pipe("This is a test", do_sample=True, top_p=0.5)