[generate] skip compilation on cpu offload (#37709)

* skip compilation on cpu offload

* add test

* better logic

* docstring

* boolean logic

* add disk offload check

* warn users if compilation options are set but compilation doesn't happen

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Author: Joao Gante
Date: 2025-04-24 14:08:17 +01:00
Committed by: GitHub
parent 7c62e69326
commit 8bdd4f2acd
4 changed files with 91 additions and 21 deletions

src/transformers/generation/configuration_utils.py

@@ -381,10 +381,12 @@ class GenerationConfig(PushToHubMixin):
         > Parameters related to performances and compilation
 
         compile_config (CompileConfig, *optional*):
-            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
-            gains.
-
-        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.
+            If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
+            inference.
+        disable_compile (`bool`, *optional*):
+            Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
+            specific criteria are met, including using a compileable cache. Please open an issue if you find the
+            need to use this flag.
 
         > Wild card
@@ -489,7 +491,7 @@ class GenerationConfig(PushToHubMixin):
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
 
         # Performance
-        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+        self.compile_config = kwargs.pop("compile_config", None)
         self.disable_compile = kwargs.pop("disable_compile", False)
 
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -811,9 +813,10 @@ class GenerationConfig(PushToHubMixin):
             self.watermarking_config.validate()
 
         # 7. performances arguments
-        if not isinstance(self.compile_config, CompileConfig):
+        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
             raise ValueError(
-                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
+                "instance of `CompileConfig`."
             )
 
         # 8. other incorrect combinations

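Illustrative usage sketch (not part of the commit) of the two options documented above. The tiny checkpoint is the same one used in the test added by this PR; the generation arguments are placeholders, and automatic compilation still only kicks in when the auto-compile criteria are met.

# `compile_config` now defaults to None: pass one only to customize `torch.compile` options;
# `disable_compile=True` opts out of automatic compilation entirely.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"  # placeholder tiny model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
inputs = tokenizer("Hello world", return_tensors="pt")

# Custom compilation options, honored only when the criteria are met
# (compileable cache, supported hardware/quantization, no CPU or disk offload)
model.generate(**inputs, max_new_tokens=3, cache_implementation="static",
               compile_config=CompileConfig(fullgraph=True))

# Never compile, even when all criteria are met
model.generate(**inputs, max_new_tokens=3, cache_implementation="static", disable_compile=True)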
src/transformers/generation/utils.py

@@ -2097,6 +2097,47 @@ class GenerationMixin:
         generation_config._pad_token_tensor = pad_token_tensor
         generation_config._decoder_start_token_tensor = decoder_start_token_tensor
 
+    def _valid_auto_compile_criteria(self, model_kwargs: Dict, generation_config: GenerationConfig) -> bool:
+        """
+        Determines whether to trigger auto-compilation of the model's forward pass at generation time.
+        """
+        # Override: honor `disable_compile` flag
+        if generation_config.disable_compile:
+            return False
+
+        # Base logic
+        valid_hardware = self.device.type == "cuda" or bool(
+            generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
+        )
+        using_compilable_cache = (
+            isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
+        )
+        can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache
+
+        # Exception 1: Some quantization methods do not support compilation
+        if getattr(self, "hf_quantizer", None) is not None:
+            can_compile &= self.hf_quantizer.is_compileable
+
+        if hasattr(self, "hf_device_map"):
+            all_model_devices = set(self.hf_device_map.values())
+            # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
+            has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1
+            can_compile &= not has_cpu_offload
+
+            # Exception 3: Disk offload is not supported for compilation
+            has_disk_offload = "disk" in all_model_devices
+            can_compile &= not has_disk_offload
+
+        # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
+        # them
+        if generation_config.compile_config is not None and not can_compile:
+            logger.warning_once(
+                "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
+                "will be skipped."
+            )
+
+        return can_compile
+
     @torch.no_grad()
     def generate(
         self,
@@ -3389,16 +3430,9 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         model_forward = self.__call__
-        if isinstance(model_kwargs.get("past_key_values"), Cache):
-            is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            if getattr(self, "hf_quantizer", None) is not None:
-                is_compileable &= self.hf_quantizer.is_compileable
-            is_compileable = is_compileable and not generation_config.disable_compile
-            if is_compileable and (
-                self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
-            ):
-                os.environ["TOKENIZERS_PARALLELISM"] = "0"
-                model_forward = self.get_compiled_call(generation_config.compile_config)
+        if self._valid_auto_compile_criteria(model_kwargs, generation_config):
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            model_forward = self.get_compiled_call(generation_config.compile_config)
 
         if generation_config.prefill_chunk_size is not None:
             model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)

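Illustrative sketch (not part of the commit; requires a CUDA device and `accelerate`) of the new warning path above: an explicit `compile_config` combined with CPU offload means `_valid_auto_compile_criteria` returns False, `generate` warns once, and the forward pass stays eager. The checkpoint and device map mirror the test added further below.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig

checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": "cpu",  # partial CPU offload -> compilation must be skipped
    "model.norm": "cpu",
    "lm_head": 0,
}
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)

# Warns once: "You have set `compile_config`, but we are unable to meet the criteria for
# compilation. Compilation will be skipped." The eager forward pass is used instead.
model.generate(**inputs, max_new_tokens=3, cache_implementation="static", compile_config=CompileConfig())
assert not hasattr(model, "_compiled_call")  # `get_compiled_call` was never reached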
src/transformers/modeling_utils.py

@@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def loss_function(self, value):
         self._loss_function = value
 
-    def get_compiled_call(self, compile_config: CompileConfig):
+    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Only reset it if not present or different from previous config
         if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
             return self.__call__
-        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
         if (
             not hasattr(self, "_compiled_call")
             or getattr(self, "_last_compile_config", default_config) != compile_config

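A tiny sketch (illustrative, not from the commit) of why the `or CompileConfig()` fallback above is needed now that `compile_config` defaults to `None`: `get_compiled_call` still wants a concrete config object to compare against the one used for the previous compilation.

from transformers.generation.configuration_utils import CompileConfig

stored_config = None                               # what generation_config.compile_config now holds by default
default_config = stored_config or CompileConfig()  # same `or` idiom as in get_compiled_call
print(default_config)                              # falls back to a default CompileConfig instance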
tests/generation/test_utils.py

@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
             # BLIP is the only exception with custom generate which call `self.lm.generate()`
             # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
             # compatible with multimodality
+            compile_config = CompileConfig()
+            compile_config._compile_all_devices = True
             if "blip" in model.__class__.__name__.lower():
-                model.language_model.generation_config.compile_config._compile_all_devices = True
+                model.language_model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.language_model.generation_config.cache_implementation = "static"
             else:
                 # force compilation (e.g. fast CI, CPU)
-                model.generation_config.compile_config._compile_all_devices = True
+                model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.generation_config.cache_implementation = "static"
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
         # If the generate doesn't infer the DECODER device map correctly, this will fail
         _ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
 
+    @require_torch_gpu
+    def test_cpu_offload_doesnt_compile(self):
+        """Test that CPU offload doesn't trigger compilation"""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
+        generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
+
+        # Sanity check: if we don't specify a device map, the model will get compiled
+        model_gpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
+        _ = model_gpu.generate(input_ids, **generate_kwargs)
+        self.assertTrue(hasattr(model_gpu, "_compiled_call"))
+
+        # If we specify a device map, the model will not be compiled
+        # (as of April 2025, compiling with CPU offload results in a crash)
+        device_map = {
+            "model.embed_tokens": 0,
+            "model.layers.0": 0,
+            "model.layers.1": "cpu",
+            "model.norm": "cpu",
+            "lm_head": 0,
+        }
+        model_cpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
+        _ = model_cpu.generate(input_ids, **generate_kwargs)
+        self.assertFalse(hasattr(model_cpu, "_compiled_call"))
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):