feat: add cache retention for requests (#38446)

* feat: add cache retention for requests

* fix: propagate `manual_eviction` param & refactor `finish_request`

`finish_request` now takes only `request_id: str` as input rather than
the full `RequestState`, which was not needed; this also simplifies
calling it from `ContinuousBatchingManager::evict_request_from_cache`
(see the usage sketch below).

* refactor: pop req from `active_requests`

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
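
A minimal usage sketch of the new manual-eviction flow. Only `init_continuous_batching(manual_eviction=...)`, `start()`, and `evict_request_from_cache()` are taken from this diff; the checkpoint name and the `add_request`/`get_result`/`stop` calls are illustrative assumptions about the surrounding manager API, not part of this change.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal LM with continuous batching support works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# With manual_eviction=True the scheduler retains a finished request's cache
# blocks, so a follow-up request re-submitted under the same request_id only
# needs to prefill the new suffix of its prompt.
manager = model.init_continuous_batching(manual_eviction=True)
manager.start()

# Hypothetical submission/retrieval helpers, named here for illustration only.
request_id = manager.add_request(tokenizer("Hello, world!").input_ids)
output = manager.get_result(request_id)

# Cache blocks are freed only when the caller evicts the request explicitly.
manager.evict_request_from_cache(request_id)
manager.stop()
```

Without the explicit `evict_request_from_cache` call, a manager created with `manual_eviction=True` keeps the finished request's blocks allocated, since `finish_request` is then invoked with `evict_from_cache=False`.
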
Luc Georges 2025-05-28 20:15:10 +02:00 committed by GitHub
parent 66da700145
commit 8010f3cf61


@@ -305,11 +305,12 @@ class Scheduler(ABC):
     It is expected that cache allocation and scheduling logic will be implemented in subclasses.
     """

-    def __init__(self, cache: PagedAttentionCache):
+    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
         self.active_requests: Dict[str, RequestState] = {}
         self.waiting_requests: Dict[str, RequestState] = {}
         self.waiting_requests_order: Deque[str] = deque()
         self.cache = cache
+        self.retain_cache_on_finish = retain_cache_on_finish

     @abstractmethod
     def add_waiting_request(self, state: RequestState):
@@ -326,7 +327,7 @@ class Scheduler(ABC):
         return self.active_requests or self.waiting_requests

     @abstractmethod
-    def finish_request(self, state: RequestState):
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
         """Finish processing a request and free its allocated blocks."""
         pass
@@ -385,6 +386,11 @@ class FIFOScheduler(Scheduler):
     @traced
     def add_waiting_request(self, state: RequestState):
         """Add a request to the waiting list."""
+        if self.retain_cache_on_finish and state.request_id in self.active_requests:
+            old_state = self.active_requests.pop(state.request_id)
+            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]
+            state.allocated_blocks = old_state.allocated_blocks
+            state.position_offset = old_state.position_offset
         self.waiting_requests[state.request_id] = state
         self.waiting_requests_order.append(state.request_id)
@@ -444,11 +450,11 @@ class FIFOScheduler(Scheduler):
         return scheduled_requests

     @traced
-    def finish_request(self, state: RequestState):
-        request_id = state.request_id
-        self.cache.free_blocks(request_id)
-        if request_id in self.active_requests:
-            del self.active_requests[request_id]
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
+        if evict_from_cache:
+            self.cache.free_blocks(request_id)
+        if request_id in self.active_requests:
+            del self.active_requests[request_id]


 @attach_tracer()
@@ -499,6 +505,11 @@ class PrefillFirstScheduler(Scheduler):
     @traced
     def add_waiting_request(self, state: RequestState):
         """Add a request to the waiting list."""
+        if self.retain_cache_on_finish and state.request_id in self.active_requests:
+            old_state = self.active_requests.pop(state.request_id)
+            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]  # XXX: check for indexing error?
+            state.allocated_blocks = old_state.allocated_blocks
+            state.position_offset = old_state.position_offset
         self.waiting_requests[state.request_id] = state
         self.waiting_requests_order.append(state.request_id)
@@ -558,11 +569,11 @@ class PrefillFirstScheduler(Scheduler):
         return scheduled_requests

     @traced
-    def finish_request(self, state: RequestState):
-        request_id = state.request_id
-        self.cache.free_blocks(request_id)
-        if request_id in self.active_requests:
-            del self.active_requests[request_id]
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
+        if evict_from_cache:
+            self.cache.free_blocks(request_id)
+        if request_id in self.active_requests:
+            del self.active_requests[request_id]


 @traced(standalone=True)
@@ -717,6 +728,7 @@ class ContinuousBatchProcessor:
         model_dtype: torch.dtype,
         scheduler: Scheduler,
         streaming: bool = False,
+        manual_eviction: bool = False,
     ):
         """Initialize the continuous batch processor.

@@ -740,6 +752,7 @@ class ContinuousBatchProcessor:
         self.model_dtype = model_dtype
         self.scheduler = scheduler
         self.streaming = streaming
+        self.manual_eviction = manual_eviction

         self.requests_in_batch: List[RequestState] = []
@@ -1002,7 +1015,7 @@ class ContinuousBatchProcessor:
                 state.prompt_ids = [token]
                 if state.update_with_token(token):
                     self.metrics.record_request_completion(state.created_time, state.request_id)
-                    self.scheduler.finish_request(state)
+                    self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
                     finished_request_ids.append(req_id)
                 self._maybe_send_output(state, token)
             elif state.status == RequestStatus.PREFILLING_SPLIT:
@@ -1019,7 +1032,7 @@ class ContinuousBatchProcessor:
             failed_reqs = self.requests_in_batch
             for req in failed_reqs:
                 self._handle_request_error(error, req)
-                self.scheduler.finish_request(req)
+                self.scheduler.finish_request(req.request_id)

     @traced
     def fail_all_requests(self, error):
@@ -1030,7 +1043,7 @@ class ContinuousBatchProcessor:
         """
         for state in self.scheduler.active_requests.values():
             self._handle_request_error(error, state)
-            self.scheduler.finish_request(state)
+            self.scheduler.finish_request(state.request_id)

         # Also fail any requests in the waiting queue
         for req_id in list(self.scheduler.waiting_requests.keys()):
@@ -1056,7 +1069,14 @@ class ContinuousBatchingManager:
     retrieving results, and managing the background generation thread.
     """

-    def __init__(self, model, generation_config: GenerationConfig, max_queue_size=0, streaming: bool = True):
+    def __init__(
+        self,
+        model,
+        generation_config: GenerationConfig,
+        manual_eviction: bool = False,
+        max_queue_size=0,
+        streaming: bool = True,
+    ):
         """Initialize the continuous batching manager.

         Args:
@@ -1080,6 +1100,8 @@ class ContinuousBatchingManager:
         self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
+        self.manual_eviction = manual_eviction
+        self.batch_processor: Optional[ContinuousBatchProcessor] = None

     @traced
     def start(self):
@@ -1262,9 +1284,11 @@ class ContinuousBatchingManager:
                 self.stop_event,
                 self.model.device,
                 self.model.dtype,
-                scheduler(paged_attention_cache),
+                scheduler(paged_attention_cache, self.manual_eviction),
                 self.streaming,
+                self.manual_eviction,
             )
+            self.batch_processor = batch_processor

             is_first = True
             if self.profile:
@@ -1346,6 +1370,14 @@ class ContinuousBatchingManager:
         if batch_processor is not None:
             batch_processor.fail_all_requests(error)

+    @traced
+    def evict_request_from_cache(self, request_id: str):
+        """Evict a request from the cache. It is assumed that the request is already finished."""
+        if not self.manual_eviction:
+            raise RuntimeError("Manual eviction is not enabled for this manager.")
+        if self.batch_processor is not None:
+            self.batch_processor.scheduler.finish_request(request_id)
+

 class ContinuousMixin:
     """Mixin class for models to add continuous batching capabilities."""
@@ -1353,8 +1385,8 @@ class ContinuousMixin:
     def init_continuous_batching(
         self,
         generation_config: Optional[GenerationConfig] = None,
+        manual_eviction: bool = False,
         max_queue_size: int = 0,
-        scheduler: str = "fifo",
         streaming: bool = False,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.
@@ -1380,7 +1412,11 @@ class ContinuousMixin:

         # Create and return the manager
         return ContinuousBatchingManager(
-            model=self, generation_config=gen_config, max_queue_size=max_queue_size, streaming=streaming
+            model=self,
+            generation_config=gen_config,
+            manual_eviction=manual_eviction,
+            max_queue_size=max_queue_size,
+            streaming=streaming,
         )

     @traced