mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix cache get item return type hints (#37847)
F: Fix cache return hints

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
aa6b79db43
commit
da7ae467c4
@@ -376,7 +376,7 @@ class DynamicCache(Cache):
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
|
||||
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||
+ def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
||||
sequence length.
|
||||
@@ -649,7 +649,7 @@ class OffloadedCache(DynamicCache):
|
||||
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
||||
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
|
||||
|
||||
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||
+ def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
||||
if layer_idx < len(self):
|
||||
# Evict the previous layer if necessary
|
||||
@@ -1473,7 +1473,7 @@ class EncoderDecoderCache(Cache):
|
||||
for layer_idx in range(len(cross_attention_cache.key_cache)):
|
||||
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
|
||||
|
||||
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||
+ def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
||||
sequence length.
|
||||
|
Loading…
Reference in New Issue
Block a user