Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[generate] skip compilation on cpu offload (#37709)
* skip compilation on cpu offload
* add test
* better logic
* docstring
* boolean logic
* add disk offload check
* warn users if compilation options are set but compilation doesn't happen
* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 7c62e69326
commit 8bdd4f2acd
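In practice, the patch makes `generate` fall back silently to the uncompiled forward pass whenever part of the model is offloaded to CPU or disk, and it only warns when the user explicitly asked for compilation. A minimal sketch of that behavior, assuming a toy checkpoint and a hand-written offload map (both are placeholders, not part of the patch):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["Hello world"], return_tensors="pt")

# Offload part of the model to CPU: auto-compilation is now skipped instead of crashing.
offload_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "cpu", "model.norm": "cpu", "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=offload_map)
input_ids = inputs.input_ids.to(model.device)

# Silent fallback: no compiled forward is created.
_ = model.generate(input_ids, max_new_tokens=3, cache_implementation="static")

# Explicit request: compilation is still skipped, but a warning is logged.
_ = model.generate(input_ids, max_new_tokens=3, cache_implementation="static", compile_config=CompileConfig())
```

The integration test added at the end of this diff exercises exactly this silent-fallback path.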
@@ -381,10 +381,12 @@ class GenerationConfig(PushToHubMixin):
         > Parameters related to performances and compilation
 
         compile_config (CompileConfig, *optional*):
-            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
-            gains.
-
-        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.
+            If using a compilable cache, this controls how `generate` will `compile` the forward pass for faster
+            inference.
+        disable_compile (`bool`, *optional*):
+            Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
+            specific criteria are met, including using a compileable cache. Please open an issue if you find the
+            need to use this flag.
 
         > Wild card
 
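For reference, both documented knobs are plain `GenerationConfig` attributes; a short sketch of how a user would set them (the values shown are illustrative, not mandated by the patch):

```python
from transformers import GenerationConfig
from transformers.generation.configuration_utils import CompileConfig

# Opt out of auto-compilation entirely, even when a compilable (e.g. static) cache is used.
no_compile = GenerationConfig(cache_implementation="static", disable_compile=True)

# Or control how the forward pass is compiled when the criteria are met.
tuned_compile = GenerationConfig(
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=True, mode="reduce-overhead"),
)
```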
@@ -489,7 +491,7 @@ class GenerationConfig(PushToHubMixin):
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
 
         # Performance
-        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+        self.compile_config = kwargs.pop("compile_config", None)
         self.disable_compile = kwargs.pop("disable_compile", False)
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -811,9 +813,10 @@ class GenerationConfig(PushToHubMixin):
             self.watermarking_config.validate()
 
         # 7. performances arguments
-        if not isinstance(self.compile_config, CompileConfig):
+        if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
             raise ValueError(
-                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
+                "instance of `CompileConfig`."
             )
 
         # 8. other incorrect combinations
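The relaxed check accepts the new `None` default and still rejects anything that is not a `CompileConfig`; a quick illustration (not part of the patch):

```python
from transformers import GenerationConfig

GenerationConfig(compile_config=None).validate()  # fine: auto-compilation criteria are decided at generation time

try:
    GenerationConfig(compile_config={"mode": "default"}).validate()  # a dict is not a CompileConfig
except ValueError as exc:
    print(exc)
```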
@@ -2097,6 +2097,47 @@ class GenerationMixin:
         generation_config._pad_token_tensor = pad_token_tensor
         generation_config._decoder_start_token_tensor = decoder_start_token_tensor
 
+    def _valid_auto_compile_criteria(self, model_kwargs: Dict, generation_config: GenerationConfig) -> bool:
+        """
+        Determines whether to trigger auto-compilation of the model's forward pass at generation time.
+        """
+        # Override: honor `disable_compile` flag
+        if generation_config.disable_compile:
+            return False
+
+        # Base logic
+        valid_hardware = self.device.type == "cuda" or bool(
+            generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
+        )
+        using_compilable_cache = (
+            isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
+        )
+        can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache
+
+        # Exception 1: Some quantization methods do not support compilation
+        if getattr(self, "hf_quantizer", None) is not None:
+            can_compile &= self.hf_quantizer.is_compileable
+
+        if hasattr(self, "hf_device_map"):
+            all_model_devices = set(self.hf_device_map.values())
+            # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
+            has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1
+            can_compile &= not has_cpu_offload
+
+            # Exception 3: Disk offload is not supported for compilation
+            has_disk_offload = "disk" in all_model_devices
+            can_compile &= not has_disk_offload
+
+        # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
+        # them
+        if generation_config.compile_config is not None and not can_compile:
+            logger.warning_once(
+                "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
+                "will be skipped."
+            )
+
+        return can_compile
+
     @torch.no_grad()
     def generate(
         self,
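Condensed, the helper's decision is just a handful of booleans; the following standalone restatement (the function name and signature are this sketch's own, not the library's) may help when reasoning about which setups end up compiled:

```python
from typing import Optional, Set


def should_auto_compile(
    device_type: str,
    compile_all_devices: bool,                 # CompileConfig._compile_all_devices, False if no config was passed
    cache_is_compileable: bool,
    supports_static_cache: bool,
    quantizer_is_compileable: Optional[bool],  # None when the model is not quantized
    device_map_devices: Optional[Set],         # values of `hf_device_map`, None when there is no device map
    disable_compile: bool,
) -> bool:
    if disable_compile:  # user override always wins
        return False
    can_compile = (device_type == "cuda" or compile_all_devices) and cache_is_compileable and supports_static_cache
    if quantizer_is_compileable is not None:  # exception 1: quantization
        can_compile &= quantizer_is_compileable
    if device_map_devices is not None:
        can_compile &= not ("cpu" in device_map_devices and len(device_map_devices) > 1)  # exception 2: CPU offload
        can_compile &= "disk" not in device_map_devices  # exception 3: disk offload
    return can_compile
```

The real method additionally warns when `compile_config` was set explicitly but the criteria are not met.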
@@ -3389,16 +3430,9 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         model_forward = self.__call__
-        if isinstance(model_kwargs.get("past_key_values"), Cache):
-            is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            if getattr(self, "hf_quantizer", None) is not None:
-                is_compileable &= self.hf_quantizer.is_compileable
-            is_compileable = is_compileable and not generation_config.disable_compile
-            if is_compileable and (
-                self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
-            ):
-                os.environ["TOKENIZERS_PARALLELISM"] = "0"
-                model_forward = self.get_compiled_call(generation_config.compile_config)
+        if self._valid_auto_compile_criteria(model_kwargs, generation_config):
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            model_forward = self.get_compiled_call(generation_config.compile_config)
 
         if generation_config.prefill_chunk_size is not None:
             model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
@@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def loss_function(self, value):
         self._loss_function = value
 
-    def get_compiled_call(self, compile_config: CompileConfig):
+    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Only reset it if not present or different from previous config
         if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
             return self.__call__
-        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
         if (
             not hasattr(self, "_compiled_call")
             or getattr(self, "_last_compile_config", default_config) != compile_config
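Because `compile_config` may now be `None` on an otherwise fully initialized `GenerationConfig`, the plain `getattr` fallback is no longer enough; a tiny sketch of why the `or CompileConfig()` is needed (illustrative only):

```python
from transformers.generation.configuration_utils import CompileConfig, GenerationConfig

generation_config = GenerationConfig()  # `compile_config` now defaults to None

# `getattr` only falls back when the attribute is missing, so the old pattern yields None here.
old_default = getattr(generation_config, "compile_config", CompileConfig())
assert old_default is None

# The `or` also covers "attribute present but None", restoring a usable default.
new_default = getattr(generation_config, "compile_config", None) or CompileConfig()
assert isinstance(new_default, CompileConfig)
```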
@@ -2245,13 +2245,15 @@ class GenerationTesterMixin:
         # BLIP is the only exception with custom generate which call `self.lm.generate()`
         # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
         # compatible with multimodality
+        compile_config = CompileConfig()
+        compile_config._compile_all_devices = True
         if "blip" in model.__class__.__name__.lower():
-            model.language_model.generation_config.compile_config._compile_all_devices = True
+            model.language_model.generation_config.compile_config = compile_config
             if not has_defined_cache_implementation:
                 model.language_model.generation_config.cache_implementation = "static"
         else:
             # force compilation (e.g. fast CI, CPU)
-            model.generation_config.compile_config._compile_all_devices = True
+            model.generation_config.compile_config = compile_config
             if not has_defined_cache_implementation:
                 model.generation_config.cache_implementation = "static"
 
@@ -4907,6 +4909,37 @@ class GenerationIntegrationTests(unittest.TestCase):
         # If the generate doesn't infer the DECODER device map correctly, this will fail
         _ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
 
+    @require_torch_gpu
+    def test_cpu_offload_doesnt_compile(self):
+        """Test that CPU offload doesn't trigger compilation"""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        tokenized_inputs = tokenizer(["Hello world"], return_tensors="pt")
+        generate_kwargs = {"max_new_tokens": 3, "cache_implementation": "static"}
+
+        # Sanity check: if we don't specify a device map, the model will get compiled
+        model_gpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_gpu.device)
+        _ = model_gpu.generate(input_ids, **generate_kwargs)
+        self.assertTrue(hasattr(model_gpu, "_compiled_call"))
+
+        # If we specify a device map, the model will not be compiled
+        # (as of April 2025, compiling with CPU offload results in a crash)
+        device_map = {
+            "model.embed_tokens": 0,
+            "model.layers.0": 0,
+            "model.layers.1": "cpu",
+            "model.norm": "cpu",
+            "lm_head": 0,
+        }
+        model_cpu = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+        )
+        input_ids = tokenized_inputs.input_ids.to(model_cpu.device)
+        _ = model_cpu.generate(input_ids, **generate_kwargs)
+        self.assertFalse(hasattr(model_cpu, "_compiled_call"))
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):