Automatic compilation in generate: do not rely on inner function (#34923)

* compiled forward in PreTrainedModel

* update

* style

* update name

* trigger CIs

* Add way to use custom compile args

* style

* switch parameterization to generation_config

* Add to inits

* Update configuration_utils.py

* inits

* style

* docs

* style

* Update configuration_utils.py

* back without dataclass for repo consistency

* Update configuration_utils.py

* style

* style

* style once again

* add config serialization

* update

* true dataclass

* trigger CIs

* merge compile methods + remove serialization of compile config
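
For context, here is a minimal end-to-end sketch of the feature this commit wires up. It closely mirrors the `CompileConfig` docstring example added in the diff below; the checkpoint and the CUDA device are illustrative assumptions rather than part of the change:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").cuda()

# Custom torch.compile arguments for the decoding forward passes.
compile_config = CompileConfig(dynamic=True)

inputs = tokenizer("Hello there, how", return_tensors="pt").to(model.device)

# Compilation only kicks in together with a static cache (see the `_sample` change below).
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=32,
    cache_implementation="static",
    compile_config=compile_config,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```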
Authored by Cyril Vallez on 2024-12-03 11:20:31 +01:00, committed by GitHub
parent f9c7e6021e
commit ee37bf0d95
6 changed files with 99 additions and 12 deletions

docs/source/en/internal/generation_utils.md

@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] SynthIDTextWatermarkDetector
- __call__
## Compile Utils
[[autodoc]] CompileConfig
- __call__

src/transformers/__init__.py

@@ -122,6 +122,7 @@ _import_structure = {
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": [
"CompileConfig",
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
@@ -4981,7 +4982,7 @@ if TYPE_CHECKING:
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser
# Integrations

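With the lazy import entries above in place, the new class is importable directly from the top-level package. A quick, hypothetical sanity check:

```python
# CompileConfig now sits next to GenerationConfig in the public API.
from transformers import CompileConfig

config = CompileConfig()  # defaults: fullgraph=True, backend="inductor", mode="reduce-overhead"
print(config)
```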
src/transformers/generation/__init__.py

@@ -20,6 +20,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure = {
"configuration_utils": [
"BaseWatermarkingConfig",
"CompileConfig",
"GenerationConfig",
"GenerationMode",
"SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@ else:
if TYPE_CHECKING:
from .configuration_utils import (
BaseWatermarkingConfig,
CompileConfig,
GenerationConfig,
GenerationMode,
SynthIDTextWatermarkingConfig,

src/transformers/generation/configuration_utils.py

@@ -20,7 +20,7 @@ import os
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from .. import __version__
from ..configuration_utils import PretrainedConfig
@@ -371,6 +371,12 @@ class GenerationConfig(PushToHubMixin):
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
> Parameters related to performances and compilation
compile_config (CompileConfig, *optional*):
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
gains.
> Wild card
generation_kwargs:
@@ -474,6 +480,9 @@
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
# Performances
self.compile_config = kwargs.pop("compile_config", CompileConfig())
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -794,7 +803,13 @@
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 7. other incorrect combinations
# 7. performances arguments
if not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
)
# 8. other incorrect combinations
if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True:
@@ -1175,6 +1190,8 @@
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
if "compile_config" in output:
del output["compile_config"]
# Transformers version when serializing this file
output["transformers_version"] = __version__
@@ -1559,3 +1576,51 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
skip_first_ngram_calls=self.skip_first_ngram_calls,
debug_mode=self.debug_mode,
)
@dataclass
class CompileConfig(object):
"""
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
Args:
fullgraph (`bool`, *optional*, defaults to `True`):
If `True`, requires that the whole forward be capturable in a single graph.
dynamic (`bool` or `None`, *optional*):
Whether to try to use dynamic shape graphs.
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
Backend to be used.
mode (`str`, *optional*, defaults to `"reduce-overhead"`):
Controls balance between performance and overhead.
options (`dict`, *optional*):
A dictionary of options to pass to the backend.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
>>> # Automatic compile configuration, used with static cache
>>> compile_config = CompileConfig(dynamic=True)
>>> # Generation with static cache and compile config
>>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
>>> output = model.generate(
... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
... )
>>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
```
"""
fullgraph: bool = True
dynamic: Optional[bool] = None
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"
options: Optional[dict] = None
def to_dict(self) -> Dict[str, Any]:
"""Serializes this instance to a Python dictionary."""
return copy.deepcopy(self.__dict__)

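Note that `to_dict` above now drops `compile_config`, matching the "remove serialization of compile config" step in the commit message: compile settings stay usable at runtime but are never written out with the rest of the generation config. A short illustration of that behavior, inferred from the deletion added to `to_dict`:

```python
from transformers import CompileConfig, GenerationConfig

gen_config = GenerationConfig(
    max_new_tokens=64,
    compile_config=CompileConfig(fullgraph=False, mode="max-autotune"),
)

# Available at runtime for `generate`...
assert isinstance(gen_config.compile_config, CompileConfig)

# ...but stripped from the serialized form, so it never lands in generation_config.json.
assert "compile_config" not in gen_config.to_dict()
```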
src/transformers/generation/utils.py

@@ -3230,16 +3230,14 @@ class GenerationMixin:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
def model_forward(model, *args, **kwargs):
return model.forward(*args, **kwargs)
model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.")
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
model_forward = self.get_compiled_call(generation_config.compile_config)
i = 0
is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
@@ -3250,11 +3248,11 @@
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
if i == 0:
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
i += 1
is_prefill = False
else:
outputs = model_forward(self, return_dict=True, **model_inputs)
outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(

src/transformers/modeling_utils.py

@@ -43,7 +43,7 @@ from torch.utils.checkpoint import checkpoint
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
@@ -5094,6 +5094,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]
def get_compiled_call(self, compile_config: CompileConfig):
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")
or getattr(self, "_last_compile_config", default_config) != compile_config
):
self._last_compile_config = compile_config
self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
return self._compiled_call
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
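
Finally, `get_compiled_call` caches the compiled callable on the model instance and only recompiles when the `CompileConfig` actually changes, so repeated `generate` calls with the same settings reuse a single compiled forward. A hedged usage sketch (the checkpoint is only illustrative):

```python
from transformers import AutoModelForCausalLM, CompileConfig

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").cuda()

first = model.get_compiled_call(CompileConfig())    # wraps self.__call__ with torch.compile
second = model.get_compiled_call(CompileConfig())   # equal dataclass -> cached callable is reused
assert first is second

retuned = model.get_compiled_call(CompileConfig(mode="max-autotune"))  # new config -> recompiled
assert retuned is not first
```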