mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
feat: add cache retention for requests (#38446)
* feat: add cache retention for requests

* fix: propagate `manual_eviction` param & refactor `finish_request`

  `finish_request` now only takes `request_id: str` as an input rather than the full `RequestState`, which was not needed and simplifies calling from `ContinuousBatchingManager::evict_request_from_cache`

* refactor: pop req from `active_requests`

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 66da700145
commit 8010f3cf61
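A minimal usage sketch of the manual-eviction flow this commit adds, based on the signatures visible in the diff below: `init_continuous_batching(manual_eviction=True)`, `start()` and `evict_request_from_cache()` appear in this diff; the `add_request`, `get_result` and `stop` calls and the checkpoint name are assumptions and may differ from the actual manager API.

# Minimal sketch of the manual-eviction flow; `add_request`, `get_result` and
# `stop` are assumed manager methods (not shown in this diff) and the checkpoint
# name is only illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# manual_eviction=True keeps a finished request's KV-cache blocks allocated
# until the caller evicts the request explicitly.
manager = model.init_continuous_batching(manual_eviction=True)
manager.start()

prompt_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids[0].tolist()
req_id = manager.add_request(prompt_ids)   # assumed: submit a request, get its id
result = manager.get_result(req_id)        # assumed: block until the request finishes

# The scheduler retained the request's blocks (evict_from_cache=False on finish),
# so a follow-up request with the same id can reuse them. When done:
manager.evict_request_from_cache(req_id)
manager.stop()                             # assumed shutdown method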
@@ -305,11 +305,12 @@ class Scheduler(ABC):
     It is expected that cache allocation and scheduling logic will be implemented in subclasses.
     """
 
-    def __init__(self, cache: PagedAttentionCache):
+    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
         self.active_requests: Dict[str, RequestState] = {}
         self.waiting_requests: Dict[str, RequestState] = {}
         self.waiting_requests_order: Deque[str] = deque()
         self.cache = cache
+        self.retain_cache_on_finish = retain_cache_on_finish
 
     @abstractmethod
     def add_waiting_request(self, state: RequestState):
@@ -326,7 +327,7 @@ class Scheduler(ABC):
         return self.active_requests or self.waiting_requests
 
     @abstractmethod
-    def finish_request(self, state: RequestState):
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
         """Finish processing a request and free its allocated blocks."""
         pass
 
@@ -385,6 +386,11 @@ class FIFOScheduler(Scheduler):
     @traced
     def add_waiting_request(self, state: RequestState):
         """Add a request to the waiting list."""
+        if self.retain_cache_on_finish and state.request_id in self.active_requests:
+            old_state = self.active_requests.pop(state.request_id)
+            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]
+            state.allocated_blocks = old_state.allocated_blocks
+            state.position_offset = old_state.position_offset
         self.waiting_requests[state.request_id] = state
         self.waiting_requests_order.append(state.request_id)
 
@@ -444,11 +450,11 @@ class FIFOScheduler(Scheduler):
         return scheduled_requests
 
     @traced
-    def finish_request(self, state: RequestState):
-        request_id = state.request_id
-        self.cache.free_blocks(request_id)
-        if request_id in self.active_requests:
-            del self.active_requests[request_id]
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
+        if evict_from_cache:
+            self.cache.free_blocks(request_id)
+            if request_id in self.active_requests:
+                del self.active_requests[request_id]
 
 
 @attach_tracer()
@@ -499,6 +505,11 @@ class PrefillFirstScheduler(Scheduler):
     @traced
     def add_waiting_request(self, state: RequestState):
         """Add a request to the waiting list."""
+        if self.retain_cache_on_finish and state.request_id in self.active_requests:
+            old_state = self.active_requests.pop(state.request_id)
+            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]  # XXX: check for indexing error?
+            state.allocated_blocks = old_state.allocated_blocks
+            state.position_offset = old_state.position_offset
         self.waiting_requests[state.request_id] = state
         self.waiting_requests_order.append(state.request_id)
 
@@ -558,11 +569,11 @@ class PrefillFirstScheduler(Scheduler):
         return scheduled_requests
 
     @traced
-    def finish_request(self, state: RequestState):
-        request_id = state.request_id
-        self.cache.free_blocks(request_id)
-        if request_id in self.active_requests:
-            del self.active_requests[request_id]
+    def finish_request(self, request_id: str, evict_from_cache: bool = True):
+        if evict_from_cache:
+            self.cache.free_blocks(request_id)
+            if request_id in self.active_requests:
+                del self.active_requests[request_id]
 
 
 @traced(standalone=True)
@@ -717,6 +728,7 @@ class ContinuousBatchProcessor:
         model_dtype: torch.dtype,
         scheduler: Scheduler,
         streaming: bool = False,
+        manual_eviction: bool = False,
     ):
         """Initialize the continuous batch processor.
 
@@ -740,6 +752,7 @@ class ContinuousBatchProcessor:
         self.model_dtype = model_dtype
         self.scheduler = scheduler
         self.streaming = streaming
+        self.manual_eviction = manual_eviction
 
         self.requests_in_batch: List[RequestState] = []
 
@@ -1002,7 +1015,7 @@ class ContinuousBatchProcessor:
                 state.prompt_ids = [token]
                 if state.update_with_token(token):
                     self.metrics.record_request_completion(state.created_time, state.request_id)
-                    self.scheduler.finish_request(state)
+                    self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
                     finished_request_ids.append(req_id)
                 self._maybe_send_output(state, token)
             elif state.status == RequestStatus.PREFILLING_SPLIT:
@@ -1019,7 +1032,7 @@ class ContinuousBatchProcessor:
         failed_reqs = self.requests_in_batch
         for req in failed_reqs:
             self._handle_request_error(error, req)
-            self.scheduler.finish_request(req)
+            self.scheduler.finish_request(req.request_id)
 
     @traced
     def fail_all_requests(self, error):
@@ -1030,7 +1043,7 @@ class ContinuousBatchProcessor:
         """
         for state in self.scheduler.active_requests.values():
             self._handle_request_error(error, state)
-            self.scheduler.finish_request(state)
+            self.scheduler.finish_request(state.request_id)
 
         # Also fail any requests in the waiting queue
         for req_id in list(self.scheduler.waiting_requests.keys()):
@@ -1056,7 +1069,14 @@ class ContinuousBatchingManager:
     retrieving results, and managing the background generation thread.
     """
 
-    def __init__(self, model, generation_config: GenerationConfig, max_queue_size=0, streaming: bool = True):
+    def __init__(
+        self,
+        model,
+        generation_config: GenerationConfig,
+        manual_eviction: bool = False,
+        max_queue_size=0,
+        streaming: bool = True,
+    ):
         """Initialize the continuous batching manager.
 
         Args:
@@ -1080,6 +1100,8 @@ class ContinuousBatchingManager:
         self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
+        self.manual_eviction = manual_eviction
+        self.batch_processor: Optional[ContinuousBatchProcessor] = None
 
     @traced
     def start(self):
@@ -1262,9 +1284,11 @@ class ContinuousBatchingManager:
             self.stop_event,
             self.model.device,
             self.model.dtype,
-            scheduler(paged_attention_cache),
+            scheduler(paged_attention_cache, self.manual_eviction),
             self.streaming,
+            self.manual_eviction,
         )
+        self.batch_processor = batch_processor
         is_first = True
 
         if self.profile:
@@ -1346,6 +1370,14 @@ class ContinuousBatchingManager:
         if batch_processor is not None:
             batch_processor.fail_all_requests(error)
 
+    @traced
+    def evict_request_from_cache(self, request_id: str):
+        """Evict a request from the cache. It is assumed that the request is already finished."""
+        if not self.manual_eviction:
+            raise RuntimeError("Manual eviction is not enabled for this manager.")
+        if self.batch_processor is not None:
+            self.batch_processor.scheduler.finish_request(request_id)
+
 
 class ContinuousMixin:
     """Mixin class for models to add continuous batching capabilities."""
@@ -1353,8 +1385,8 @@ class ContinuousMixin:
     def init_continuous_batching(
         self,
         generation_config: Optional[GenerationConfig] = None,
+        manual_eviction: bool = False,
         max_queue_size: int = 0,
         scheduler: str = "fifo",
         streaming: bool = False,
     ) -> ContinuousBatchingManager:
         """Initialize a manager for continuous batching inference.
@@ -1380,7 +1412,11 @@ class ContinuousMixin:
 
         # Create and return the manager
         return ContinuousBatchingManager(
-            model=self, generation_config=gen_config, max_queue_size=max_queue_size, streaming=streaming
+            model=self,
+            generation_config=gen_config,
+            manual_eviction=manual_eviction,
+            max_queue_size=max_queue_size,
+            streaming=streaming,
         )
 
     @traced
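For reference, a toy sketch of the retention path added to `add_waiting_request` above, using a simplified stand-in for `RequestState` rather than the library's classes: when a finished request is resubmitted under the same id, the prefix matching the originally submitted prompt is dropped and the previously allocated blocks and position offset are reused.

# Toy stand-in for the retention logic in `add_waiting_request` (simplified;
# ToyState is not the library's RequestState).
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyState:
    request_id: str
    prompt_ids: List[int]                                      # tokens still to be prefilled
    full_prompt_ids: List[int] = field(default_factory=list)   # prompt as originally submitted
    allocated_blocks: List[int] = field(default_factory=list)
    position_offset: int = 0


active_requests = {}

# First turn finished with evict_from_cache=False, so its state stayed in
# active_requests: 5 prompt tokens plus 2 generated tokens -> position_offset 7.
active_requests["chat-1"] = ToyState(
    "chat-1", prompt_ids=[], full_prompt_ids=[1, 2, 3, 4, 5],
    allocated_blocks=[0, 1], position_offset=7,
)

# Second turn resubmits the conversation so far plus new tokens under the same id.
new = ToyState("chat-1", prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8])

retain_cache_on_finish = True
if retain_cache_on_finish and new.request_id in active_requests:
    old_state = active_requests.pop(new.request_id)
    # Drop the prefix that matches the originally submitted prompt, as in the diff.
    new.prompt_ids = new.prompt_ids[len(old_state.full_prompt_ids):]
    new.allocated_blocks = old_state.allocated_blocks
    new.position_offset = old_state.position_offset

print(new.prompt_ids)        # [6, 7, 8] -- only the unseen suffix is prefilled
print(new.allocated_blocks)  # [0, 1]    -- reuses the retained cache blocks
print(new.position_offset)   # 7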