mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Merge 9b3ce7e9e3
into 2d561713f8
This commit is contained in:
commit
071bb487b5
@ -1139,11 +1139,12 @@ class StaticCache(Cache):
|
||||
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
||||
|
||||
self._dtype = dtype
|
||||
tp_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||
self.num_key_value_heads = (
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
) // tp_size
|
||||
|
||||
self.key_cache: list[torch.Tensor] = []
|
||||
self.value_cache: list[torch.Tensor] = []
|
||||
|
Loading…
Reference in New Issue
Block a user