This commit is contained in:
jiqing-feng 2025-07-02 11:51:22 -07:00 committed by GitHub
commit 071bb487b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1139,11 +1139,12 @@ class StaticCache(Cache):
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
self._dtype = dtype
tp_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
) // tp_size
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []