[generate] skip compilation on cpu offload (#37709)

* skip compilation on cpu offload

* add test

* better logic

* docstring

* boolean logic

* add disk offload check

* warn users if compilation options are set but compilation doesn't happen

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Joao Gante 2025-04-24 14:08:17 +01:00 committed by GitHub
parent 7c62e69326
commit 8bdd4f2acd
4 changed files with 91 additions and 21 deletions
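
For context, a minimal sketch (not part of the commit) of the situation this change addresses: a model partially offloaded to CPU via `device_map` now skips auto-compilation during `generate` instead of crashing. The checkpoint and device map are taken from the new test added below; it assumes one CUDA GPU and `accelerate` installed.

# Sketch only: generate with CPU offload now skips torch.compile instead of crashing.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(["Hello world"], return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map={
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": "cpu",  # part of the model lives on CPU -> auto-compilation is skipped
        "model.norm": "cpu",
        "lm_head": 0,
    },
)
_ = model.generate(inputs.input_ids.to(model.device), max_new_tokens=3, cache_implementation="static")
assert not hasattr(model, "_compiled_call")  # the forward pass was not torch.compile'd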

View File

@@ -381,10 +381,12 @@ class GenerationConfig(PushToHubMixin):
     > Parameters related to performances and compilation
 
         compile_config (CompileConfig, *optional*):
-            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
-            gains.
-        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.
+            If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
+            inference.
+        disable_compile (`bool`, *optional*):
+            Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
+            specific criteria are met, including using a compileable cache. Please open an issue if you find the
+            need to use this flag.
 
     > Wild card
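
For illustration, a short sketch of how the two options documented above are meant to be used from user code. The checkpoint name is only an example, and the top-level `CompileConfig` import path is assumed; it may differ slightly across versions.

# Illustrative sketch only: exercising `compile_config` and `disable_compile`.
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # example checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello world", return_tensors="pt")

# Explicitly control how the forward pass is compiled when a compileable cache is used...
compiled_cfg = GenerationConfig(
    max_new_tokens=3, cache_implementation="static", compile_config=CompileConfig()
)

# ...or opt out of auto-compilation entirely, even with a compileable cache.
no_compile_cfg = GenerationConfig(
    max_new_tokens=3, cache_implementation="static", disable_compile=True
)

_ = model.generate(**inputs, generation_config=no_compile_cfg)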
@@ -489,7 +491,7 @@ class GenerationConfig(PushToHubMixin):
 
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
         # Performance
-        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+        self.compile_config = kwargs.pop("compile_config", None)
         self.disable_compile = kwargs.pop("disable_compile", False)
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -811,9 +813,10 @@ class GenerationConfig(PushToHubMixin):
             self.watermarking_config.validate()
 
         # 7. performances arguments
-        if not isinstance(self.compile_config, CompileConfig):
+        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
             raise ValueError(
-                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
+                "instance of `CompileConfig`."
             )
 
         # 8. other incorrect combinations
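
A small sketch of what the relaxed check accepts and rejects, assuming `validate()` can be called directly on a config instance:

# Sketch: `compile_config` may now be None (the new default), but any other
# non-CompileConfig value still fails validation with the error above.
from transformers import CompileConfig, GenerationConfig

config = GenerationConfig()              # compile_config defaults to None after this change
config.validate()                        # passes

config.compile_config = CompileConfig()  # still accepted
config.validate()                        # passes

config.compile_config = {"fullgraph": True}  # wrong type
config.validate()                            # raises ValueError with the message above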


@@ -2097,6 +2097,47 @@ class GenerationMixin:
         generation_config._pad_token_tensor = pad_token_tensor
         generation_config._decoder_start_token_tensor = decoder_start_token_tensor
 
+    def _valid_auto_compile_criteria(self, model_kwargs: Dict, generation_config: GenerationConfig) -> bool:
+        """
+        Determines whether to trigger auto-compilation of the model's forward pass at generation time.
+        """
+        # Override: honor `disable_compile` flag
+        if generation_config.disable_compile:
+            return False
+
+        # Base logic
+        valid_hardware = self.device.type == "cuda" or bool(
+            generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
+        )
+        using_compilable_cache = (
+            isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
+        )
+        can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache
+
+        # Exception 1: Some quantization methods do not support compilation
+        if getattr(self, "hf_quantizer", None) is not None:
+            can_compile &= self.hf_quantizer.is_compileable
+
+        if hasattr(self, "hf_device_map"):
+            all_model_devices = set(self.hf_device_map.values())
+            # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
+            has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1
+            can_compile &= not has_cpu_offload
+
+            # Exception 3: Disk offload is not supported for compilation
+            has_disk_offload = "disk" in all_model_devices
+            can_compile &= not has_disk_offload
+
+        # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
+        # them
+        if generation_config.compile_config is not None and not can_compile:
+            logger.warning_once(
+                "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
+                "will be skipped."
+            )
+
+        return can_compile
+
     @torch.no_grad()
     def generate(
         self,
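
As a quick illustration of the warning path added above, a sketch where `compile_config` is set but the model runs on CPU only, so the hardware criterion is not met: the helper returns False, the "will be skipped" warning is emitted, and no compilation happens. The checkpoint name is only an example.

# Sketch: compile_config is set, but on a CPU-only model the criteria above fail.
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # example checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint)           # stays on CPU
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello world", return_tensors="pt")

gen_config = GenerationConfig(
    max_new_tokens=3, cache_implementation="static", compile_config=CompileConfig()
)
_ = model.generate(**inputs, generation_config=gen_config)  # logs the "will be skipped" warning
assert not hasattr(model, "_compiled_call")                 # forward pass was not compiled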
@@ -3389,16 +3430,9 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         model_forward = self.__call__
-        if isinstance(model_kwargs.get("past_key_values"), Cache):
-            is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            if getattr(self, "hf_quantizer", None) is not None:
-                is_compileable &= self.hf_quantizer.is_compileable
-            is_compileable = is_compileable and not generation_config.disable_compile
-            if is_compileable and (
-                self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
-            ):
-                os.environ["TOKENIZERS_PARALLELISM"] = "0"
-                model_forward = self.get_compiled_call(generation_config.compile_config)
+        if self._valid_auto_compile_criteria(model_kwargs, generation_config):
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            model_forward = self.get_compiled_call(generation_config.compile_config)
 
         if generation_config.prefill_chunk_size is not None:
             model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)


@@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def loss_function(self, value):
         self._loss_function = value
 
-    def get_compiled_call(self, compile_config: CompileConfig):
+    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Only reset it if not present or different from previous config
         if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
             return self.__call__
-        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
         if (
             not hasattr(self, "_compiled_call")
             or getattr(self, "_last_compile_config", default_config) != compile_config


@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
             # BLIP is the only exception with custom generate which call `self.lm.generate()`
             # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
             # compatible with multimodality
+            compile_config = CompileConfig()
+            compile_config._compile_all_devices = True
             if "blip" in model.__class__.__name__.lower():
-                model.language_model.generation_config.compile_config._compile_all_devices = True
+                model.language_model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.language_model.generation_config.cache_implementation = "static"
             else:
                 # force compilation (e.g. fast CI, CPU)
-                model.generation_config.compile_config._compile_all_devices = True
+                model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.generation_config.cache_implementation = "static"
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
         # If the generate doesn't infer the DECODER device map correctly, this will fail
         _ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
 
+    @require_torch_gpu
+    def test_cpu_offload_doesnt_compile(self):
+        """Test that CPU offload doesn't trigger compilation"""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
+        generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
+
+        # Sanity check: if we don't specify a device map, the model will get compiled
+        model_gpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
+        _ = model_gpu.generate(input_ids, **generate_kwargs)
+        self.assertTrue(hasattr(model_gpu, "_compiled_call"))
+
+        # If we specify a device map, the model will not be compiled
+        # (as of April 2025, compiling with CPU offload results in a crash)
+        device_map = {
+            "model.embed_tokens": 0,
+            "model.layers.0": 0,
+            "model.layers.1": "cpu",
+            "model.norm": "cpu",
+            "lm_head": 0,
+        }
+        model_cpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
+        _ = model_cpu.generate(input_ids, **generate_kwargs)
+        self.assertFalse(hasattr(model_cpu, "_compiled_call"))
 
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):