Fix cache `__getitem__` return type hints (#37847)

F: Fix cache return hints

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Chris 2025-04-29 06:23:52 -07:00 committed by GitHub
parent aa6b79db43
commit da7ae467c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -376,7 +376,7 @@ class DynamicCache(Cache):
self.key_cache.append(key_states)
self.value_cache.append(value_states)
-def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
@@ -649,7 +649,7 @@ class OffloadedCache(DynamicCache):
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
-def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
# Evict the previous layer if necessary
@@ -1473,7 +1473,7 @@ class EncoderDecoderCache(Cache):
for layer_idx in range(len(cross_attention_cache.key_cache)):
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
-def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.