From 8bdd4f2acd9bd379b31dd21916a24f8150a1efa2 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 24 Apr 2025 14:08:17 +0100
Subject: [PATCH] [generate] skip compilation on cpu offload (#37709)

* skip compilation on cpu offload

* add test

* better logic

* docstring

* boolean logic

* add disk offload check

* warn users if compilation options are set but compilation doesn't happen

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 .../generation/configuration_utils.py | 17 +++---
 src/transformers/generation/utils.py  | 54 +++++++++++++++----
 src/transformers/modeling_utils.py    |  4 +-
 tests/generation/test_utils.py        | 37 ++++++++++++-
 4 files changed, 91 insertions(+), 21 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index f0ceaed5195..100b11fc748 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -381,10 +381,12 @@ class GenerationConfig(PushToHubMixin):
         > Parameters related to performances and compilation
 
         compile_config (CompileConfig, *optional*):
-            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
-            gains.
-
-        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.
+            If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
+            inference.
+        disable_compile (`bool`, *optional*):
+            Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
+            specific criteria are met, including using a compileable cache. Please open an issue if you find the
+            need to use this flag.
 
         > Wild card
 
@@ -489,7 +491,7 @@ class GenerationConfig(PushToHubMixin):
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
 
         # Performance
-        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+        self.compile_config = kwargs.pop("compile_config", None)
         self.disable_compile = kwargs.pop("disable_compile", False)
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -811,9 +813,10 @@ class GenerationConfig(PushToHubMixin):
             self.watermarking_config.validate()
 
         # 7. performances arguments
-        if not isinstance(self.compile_config, CompileConfig):
+        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
             raise ValueError(
-                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
+                "instance of `CompileConfig`."
             )
 
         # 8. other incorrect combinations
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2f9118b0dac..6dc9acf8a09 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2097,6 +2097,47 @@ class GenerationMixin:
         generation_config._pad_token_tensor = pad_token_tensor
         generation_config._decoder_start_token_tensor = decoder_start_token_tensor
 
+    def _valid_auto_compile_criteria(self, model_kwargs: Dict, generation_config: GenerationConfig) -> bool:
+        """
+        Determines whether to trigger auto-compilation of the model's forward pass at generation time.
+ """ + # Override: honor `disable_compile` flag + if generation_config.disable_compile: + return False + + # Base logic + valid_hardware = self.device.type == "cuda" or bool( + generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices + ) + using_compilable_cache = ( + isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable + ) + can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache + + # Exception 1: Some quantization methods do not support compilation + if getattr(self, "hf_quantizer", None) is not None: + can_compile &= self.hf_quantizer.is_compileable + + if hasattr(self, "hf_device_map"): + all_model_devices = set(self.hf_device_map.values()) + # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash) + has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1 + can_compile &= not has_cpu_offload + + # Exception 3: Disk offload is not supported for compilation + has_disk_offload = "disk" in all_model_devices + can_compile &= not has_disk_offload + + # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn + # them + if generation_config.compile_config is not None and not can_compile: + logger.warning_once( + "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation " + "will be skipped." + ) + + return can_compile + @torch.no_grad() def generate( self, @@ -3389,16 +3430,9 @@ class GenerationMixin: model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) model_forward = self.__call__ - if isinstance(model_kwargs.get("past_key_values"), Cache): - is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache - if getattr(self, "hf_quantizer", None) is not None: - is_compileable &= self.hf_quantizer.is_compileable - is_compileable = is_compileable and not generation_config.disable_compile - if is_compileable and ( - self.device.type == "cuda" or generation_config.compile_config._compile_all_devices - ): - os.environ["TOKENIZERS_PARALLELISM"] = "0" - model_forward = self.get_compiled_call(generation_config.compile_config) + if self._valid_auto_compile_criteria(model_kwargs, generation_config): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) if generation_config.prefill_chunk_size is not None: model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6999d46e492..adac0890e6d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi def loss_function(self, value): self._loss_function = value - def get_compiled_call(self, compile_config: CompileConfig): + def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable: """Return a `torch.compile`'d version of `self.__call__`. 
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Only reset it if not present or different from previous config
         if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
             return self.__call__
-        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
         if (
             not hasattr(self, "_compiled_call")
             or getattr(self, "_last_compile_config", default_config) != compile_config
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 1aa88abcb97..b9916e8dfa8 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
             # BLIP is the only exception with custom generate which call `self.lm.generate()`
             # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
             # compatible with multimodality
+            compile_config = CompileConfig()
+            compile_config._compile_all_devices = True
             if "blip" in model.__class__.__name__.lower():
-                model.language_model.generation_config.compile_config._compile_all_devices = True
+                model.language_model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.language_model.generation_config.cache_implementation = "static"
             else:
                 # force compilation (e.g. fast CI, CPU)
-                model.generation_config.compile_config._compile_all_devices = True
+                model.generation_config.compile_config = compile_config
                 if not has_defined_cache_implementation:
                     model.generation_config.cache_implementation = "static"
 
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
         # If the generate doesn't infer the DECODER device map correctly, this will fail
         _ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
 
+    @require_torch_gpu
+    def test_cpu_offload_doesnt_compile(self):
+        """Test that CPU offload doesn't trigger compilation"""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
+        generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
+
+        # Sanity check: if we don't specify a device map, the model will get compiled
+        model_gpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
+        _ = model_gpu.generate(input_ids, **generate_kwargs)
+        self.assertTrue(hasattr(model_gpu, "_compiled_call"))
+
+        # If we specify a device map, the model will not be compiled
+        # (as of April 2025, compiling with CPU offload results in a crash)
+        device_map = {
+            "model.embed_tokens": 0,
+            "model.layers.0": 0,
+            "model.layers.1": "cpu",
+            "model.norm": "cpu",
+            "lm_head": 0,
+        }
+        model_cpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
+        _ = model_cpu.generate(input_ids, **generate_kwargs)
+        self.assertFalse(hasattr(model_cpu, "_compiled_call"))
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
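
Usage sketch (not part of the patch): a minimal, hedged illustration of the behavior this change
affects from user code, assuming the tiny test checkpoint used in the new test and the top-level
`CompileConfig` export. With a compileable ("static") cache on a CUDA device, `generate` compiles
the decoding forward pass automatically; with CPU or disk offload it now skips compilation and
warns when `compile_config` was set, instead of crashing.

    from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
    )
    inputs = tokenizer(["Hello world"], return_tensors="pt").to(model.device)

    # A compileable cache + supported hardware meets the auto-compile criteria; `compile_config`
    # only tunes how the forward pass is compiled (it is no longer required for compilation).
    _ = model.generate(
        **inputs,
        max_new_tokens=3,
        cache_implementation="static",
        compile_config=CompileConfig(fullgraph=True),
    )

    # With a `device_map` that offloads modules to "cpu" or "disk", the same call now skips
    # compilation (and emits a warning if `compile_config` was set explicitly).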