mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[pipeline
] A simple fix for half-precision & 8bit models (#21479)
* v1 fix * adapt from suggestions * make style * fix tests * add gpu tests * update docs * fix other tests * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * better fix * make fixup * better example * revert changes * proposal * more elegant solution * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
97d3390fc8
commit
f83942684d
@ -105,6 +105,8 @@ If the model is too large for a single GPU, you can set `device_map="auto"` to a
|
||||
generator(model="openai/whisper-large", device_map="auto")
|
||||
```
|
||||
|
||||
Note that if `device_map="auto"` is passed, there is no need to add the argument `device=device` when instantiating your `pipeline` as you may encounter some unexpected behavior!
|
||||
|
||||
### Batch size
|
||||
|
||||
By default, pipelines will not batch inference for reasons explained in detail [here](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching). The reason is that batching is not necessarily faster, and can actually be quite slower in some cases.
|
||||
@ -257,4 +259,32 @@ sudo apt install -y tesseract-ocr
|
||||
pip install pytesseract
|
||||
```
|
||||
|
||||
</Tip>
|
||||
</Tip>
|
||||
|
||||
## Using `pipeline` on large models with 🤗 `accelerate`:
|
||||
|
||||
You can easily run `pipeline` on large models using 🤗 `accelerate`! First make sure you have installed `accelerate` with `pip install accelerate`.
|
||||
|
||||
First load your model using `device_map="auto"`! We will use `facebook/opt-1.3b` for our example.
|
||||
|
||||
```py
|
||||
# pip install accelerate
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(model="facebook/opt-1.3b", torch_dtype=torch.bfloat16, device_map="auto")
|
||||
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
|
||||
```
|
||||
|
||||
You can also pass 8-bit loaded models if you install `bitsandbytes` and add the argument `load_in_8bit=True`
|
||||
|
||||
```py
|
||||
# pip install accelerate bitsandbytes
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", model_kwargs={"load_in_8bit": True})
|
||||
output = pipe("This is a cool example!", do_sample=True, top_p=0.95)
|
||||
```
|
||||
|
||||
Note that you can replace the checkpoint with any of the Hugging Face model that supports large model loading such as BLOOM!
|
@ -738,6 +738,11 @@ def pipeline(
|
||||
'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those'
|
||||
" arguments might conflict, use only one.)"
|
||||
)
|
||||
if device is not None:
|
||||
logger.warning(
|
||||
"Both `device` and `device_map` are specified. `device` will override `device_map`. You"
|
||||
" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
|
||||
)
|
||||
model_kwargs["device_map"] = device_map
|
||||
if torch_dtype is not None:
|
||||
if "torch_dtype" in model_kwargs:
|
||||
|
@ -286,9 +286,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
installed. If no framework is specified, will default to the one currently installed. If no framework is
|
||||
specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
|
||||
no model is provided.
|
||||
device (`int`, *optional*, defaults to -1):
|
||||
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
|
||||
the associated CUDA device id.
|
||||
device (Union[`int`, `torch.device`], *optional*):
|
||||
Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
|
||||
model on the associated CUDA device id.
|
||||
decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
|
||||
[PyCTCDecode's
|
||||
BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
|
||||
|
@ -749,7 +749,7 @@ class Pipeline(_ScikitCompat):
|
||||
framework: Optional[str] = None,
|
||||
task: str = "",
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: Union[int, str, "torch.device"] = -1,
|
||||
device: Union[int, str, "torch.device"] = None,
|
||||
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
||||
binary_output: bool = False,
|
||||
**kwargs,
|
||||
@ -764,6 +764,19 @@ class Pipeline(_ScikitCompat):
|
||||
self.image_processor = image_processor
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
if self.framework == "pt" and device is not None:
|
||||
self.model = self.model.to(device=device)
|
||||
|
||||
if device is None:
|
||||
# `accelerate` device map
|
||||
hf_device_map = getattr(self.model, "hf_device_map", None)
|
||||
if hf_device_map is not None:
|
||||
# Take the first device used by `accelerate`.
|
||||
device = next(iter(hf_device_map.values()))
|
||||
else:
|
||||
device = -1
|
||||
|
||||
if is_torch_available() and self.framework == "pt":
|
||||
if isinstance(device, torch.device):
|
||||
self.device = device
|
||||
@ -774,14 +787,10 @@ class Pipeline(_ScikitCompat):
|
||||
else:
|
||||
self.device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
self.device = device
|
||||
self.device = device if device is not None else -1
|
||||
self.torch_dtype = torch_dtype
|
||||
self.binary_output = binary_output
|
||||
|
||||
# Special handling
|
||||
if self.framework == "pt" and self.device.type != "cpu":
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# Update config with task specific parameters
|
||||
task_specific_params = self.model.config.task_specific_params
|
||||
if task_specific_params is not None and task in task_specific_params:
|
||||
|
@ -255,7 +255,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
device: int = -1,
|
||||
task: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
@ -264,7 +263,6 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
||||
tokenizer=tokenizer,
|
||||
modelcard=modelcard,
|
||||
framework=framework,
|
||||
device=device,
|
||||
task=task,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -312,3 +312,12 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
|
||||
pipe("This is a test")
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
def test_pipeline_accelerate_top_p(self):
|
||||
import torch
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||
|
Loading…
Reference in New Issue
Block a user