Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Automatic compilation in generate: do not rely on inner function (#34923)
* compiled forward in PreTrainedModel
* update
* style
* update name
* trigger CIs
* Add way to use custom compile args
* style
* switch parameterization to generation_config
* Add to inits
* Update configuration_utils.py
* inits
* style
* docs
* style
* Update configuration_utils.py
* back without dataclass for repo consistency
* Update configuration_utils.py
* style
* style
* style once again
* add config serialization
* update
* true dataclass
* trigger CIs
* merge compile methods + remove serialization of compile config
This commit is contained in: parent f9c7e6021e, commit ee37bf0d95
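The change replaces the inner `model_forward` helper previously defined inside the generation loop with a cached, configurable `PreTrainedModel.get_compiled_call`, driven by a new `CompileConfig` carried by `GenerationConfig`. A condensed usage sketch, adapted from the docstring example added further down (checkpoint and generation arguments are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

# Illustrative checkpoint, taken from the docstring example added in this commit.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").cuda()

# Controls how `generate` will torch.compile the decoding forward pass.
compile_config = CompileConfig(fullgraph=True, mode="reduce-overhead")

inputs = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
# `compile_config` only takes effect together with a static cache.
output = model.generate(
    inputs, do_sample=False, max_new_tokens=32, cache_implementation="static", compile_config=compile_config
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```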
@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] SynthIDTextWatermarkDetector
     - __call__
+
+## Compile Utils
+
+[[autodoc]] CompileConfig
+    - __call__
@@ -122,6 +122,7 @@ _import_structure = {
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "CompileConfig",
         "GenerationConfig",
         "TextIteratorStreamer",
         "TextStreamer",
@@ -4981,7 +4982,7 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
     from .hf_argparser import HfArgumentParser
 
     # Integrations
@@ -20,6 +20,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
 _import_structure = {
     "configuration_utils": [
         "BaseWatermarkingConfig",
+        "CompileConfig",
         "GenerationConfig",
         "GenerationMode",
         "SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@ else:
 if TYPE_CHECKING:
     from .configuration_utils import (
         BaseWatermarkingConfig,
+        CompileConfig,
         GenerationConfig,
         GenerationMode,
         SynthIDTextWatermarkingConfig,
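With the lazy-import tables and `TYPE_CHECKING` branches updated above, the new class should be importable both from the top-level package and from the generation submodule; a quick sanity check (sketch):

```python
# Both import paths are expected to resolve to the same class after this change.
from transformers import CompileConfig
from transformers.generation import CompileConfig as GenerationCompileConfig

assert CompileConfig is GenerationCompileConfig
```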
@@ -20,7 +20,7 @@ import os
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 from .. import __version__
 from ..configuration_utils import PretrainedConfig
@@ -371,6 +371,12 @@ class GenerationConfig(PushToHubMixin):
             to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
             See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
 
+        > Parameters related to performances and compilation
+
+        compile_config (CompileConfig, *optional*):
+            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
+            gains.
+
         > Wild card
 
         generation_kwargs:
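Since `compile_config` is now a documented `GenerationConfig` field, the compilation behaviour can be set once on the config instead of being passed to each `generate` call; a hedged sketch:

```python
from transformers import GenerationConfig
from transformers.generation import CompileConfig

# compile_config is only consulted when generation runs with a static cache.
generation_config = GenerationConfig(
    max_new_tokens=64,
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=True, mode="reduce-overhead"),
)
# Later: model.generate(**inputs, generation_config=generation_config)
```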
@@ -474,6 +480,9 @@ class GenerationConfig(PushToHubMixin):
         self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
 
+        # Performances
+        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -794,7 +803,13 @@ class GenerationConfig(PushToHubMixin):
                 self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
             self.watermarking_config.validate()
 
-        # 7. other incorrect combinations
+        # 7. performances arguments
+        if not isinstance(self.compile_config, CompileConfig):
+            raise ValueError(
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+            )
+
+        # 8. other incorrect combinations
         if self.return_dict_in_generate is not True:
             for extra_output_flag in self.extra_output_flags:
                 if getattr(self, extra_output_flag) is True:
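The new validation step rejects anything that is not a `CompileConfig` instance, for example a plain dict of `torch.compile` kwargs; assuming validation runs at construction time as it does for the other fields, roughly:

```python
from transformers import GenerationConfig

try:
    # Raw torch.compile kwargs are not accepted; a CompileConfig instance is required.
    GenerationConfig(compile_config={"fullgraph": True})
except ValueError as err:
    print(err)  # "... it must be an instance of `CompileConfig`."
```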
@@ -1175,6 +1190,8 @@ class GenerationConfig(PushToHubMixin):
             del output["_commit_hash"]
         if "_original_object_hash" in output:
             del output["_original_object_hash"]
+        if "compile_config" in output:
+            del output["compile_config"]
 
         # Transformers version when serializing this file
         output["transformers_version"] = __version__
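Because `compile_config` is stripped from the output dictionary here (this hunk appears to sit in `GenerationConfig.to_dict()`, given the surrounding `_commit_hash` handling), compile settings are treated as a runtime concern and are not serialized with the config; a small sketch:

```python
from transformers import GenerationConfig
from transformers.generation import CompileConfig

config = GenerationConfig(compile_config=CompileConfig(mode="reduce-overhead"))
# The compile settings are dropped from the serialized form.
assert "compile_config" not in config.to_dict()
```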
@@ -1559,3 +1576,51 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
             skip_first_ngram_calls=self.skip_first_ngram_calls,
             debug_mode=self.debug_mode,
         )
+
+
+@dataclass
+class CompileConfig(object):
+    """
+    Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
+    See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
+
+    Args:
+        fullgraph (`bool`, *optional*, defaults to `True`):
+            If `True`, requires that the whole forward be capturable in a single graph.
+        dynamic (`bool` or `None`, *optional*):
+            Whether to try to use dynamic shape graphs.
+        backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
+            Backend to be used.
+        mode (`str`, *optional*, defaults to `"reduce-overhead"`):
+            Controls balance between performance and overhead.
+        options (`dict`, *optional*):
+            A dictionary of options to pass to the backend.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
+
+    >>> # Automatic compile configuration, used with static cache
+    >>> compile_config = CompileConfig(dynamic=True)
+
+    >>> # Generation with static cache and compile config
+    >>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
+    >>> output = model.generate(
+    ...     input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
+    ... )
+    >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+    ```
+    """
+
+    fullgraph: bool = True
+    dynamic: Optional[bool] = None
+    backend: Union[str, Callable] = "inductor"
+    mode: str = "reduce-overhead"
+    options: Optional[dict] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serializes this instance to a Python dictionary."""
+        return copy.deepcopy(self.__dict__)
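Since `CompileConfig` is a plain dataclass, `to_dict()` just deep-copies its fields, so the result can be splatted directly into `torch.compile`; a minimal sketch of that relationship (the toy function is illustrative):

```python
import torch
from transformers.generation import CompileConfig

config = CompileConfig(dynamic=False, mode="max-autotune")
# to_dict() yields keyword arguments in the shape torch.compile expects,
# e.g. {'fullgraph': True, 'dynamic': False, 'backend': 'inductor', 'mode': 'max-autotune', 'options': None}
compile_kwargs = config.to_dict()

def forward_fn(x):
    return x * 2

compiled_fn = torch.compile(forward_fn, **compile_kwargs)
```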
@@ -3230,16 +3230,14 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        def model_forward(model, *args, **kwargs):
-            return model.forward(*args, **kwargs)
-
+        model_forward = self.__call__
         if isinstance(model_kwargs.get("past_key_values"), StaticCache):
             if self.device.type == "cuda":
                 logger.warning_once("Using `torch.compile`.")
                 os.environ["TOKENIZERS_PARALLELISM"] = "0"
-                model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+                model_forward = self.get_compiled_call(generation_config.compile_config)
 
-        i = 0
+        is_prefill = True
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
@@ -3250,11 +3248,11 @@
             model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
             model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
 
-            if i == 0:
+            if is_prefill:
                 outputs = self(**model_inputs, return_dict=True)
-                i += 1
+                is_prefill = False
             else:
-                outputs = model_forward(self, return_dict=True, **model_inputs)
+                outputs = model_forward(**model_inputs, return_dict=True)
 
             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
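The loop now keeps the eager `self.__call__` for the prefill step, whose prompt-dependent shapes would otherwise force a recompilation, and only routes the fixed-shape decoding steps through the compiled call. A toy stand-in for that dispatch pattern (not transformers code, just an illustration of the shape argument):

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = ToyModel()
# Compiled variant of __call__, analogous to what get_compiled_call returns.
compiled_forward = torch.compile(model.__call__, fullgraph=True, mode="reduce-overhead")

prompt = torch.randn(1, 5, 8)  # "prefill": prompt-length-dependent shape, run eagerly
step = torch.randn(1, 1, 8)    # "decode": fixed shape, worth compiling

is_prefill = True
for _ in range(3):
    if is_prefill:
        out = model(prompt)           # eager prefill avoids compiling for the prompt shape
        is_prefill = False
    else:
        out = compiled_forward(step)  # static-shape decoding steps reuse one compiled graph
```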
@@ -43,7 +43,7 @@ from torch.utils.checkpoint import checkpoint
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import GenerationConfig, GenerationMixin
+from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .loss.loss_utils import LOSS_MAPPING
 from .pytorch_utils import (  # noqa: F401
@@ -5094,6 +5094,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             loss_type = "ForCausalLM"
         return LOSS_MAPPING[loss_type]
 
+    def get_compiled_call(self, compile_config: CompileConfig):
+        """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
+        non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
+        want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
+        (where we want the speed-ups of compiled version with static shapes)."""
+        # Only reset it if not present or different from previous config
+        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        if (
+            not hasattr(self, "_compiled_call")
+            or getattr(self, "_last_compile_config", default_config) != compile_config
+        ):
+            self._last_compile_config = compile_config
+            self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
+        return self._compiled_call
+
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
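`get_compiled_call` memoizes the compiled callable on the model (`_compiled_call`) and only recompiles when a different `CompileConfig` is passed, so repeated `generate` calls with the same settings reuse one compiled graph; roughly (the tiny checkpoint is illustrative):

```python
from transformers import AutoModelForCausalLM
from transformers.generation import CompileConfig

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative tiny checkpoint

config = CompileConfig(fullgraph=True)
first = model.get_compiled_call(config)
second = model.get_compiled_call(config)
assert first is second  # same config -> the cached compiled call is reused

third = model.get_compiled_call(CompileConfig(mode="max-autotune"))
assert third is not first  # a different config triggers a fresh torch.compile
```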