Mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Merge 9b3ce7e9e3 into 2d561713f8
This commit is contained in:
commit
071bb487b5
@@ -1139,11 +1139,12 @@ class StaticCache(Cache):
|
|||||||
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
||||||
|
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
tp_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_attention_heads
|
config.num_attention_heads
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
else config.num_key_value_heads
|
else config.num_key_value_heads
|
||||||
)
|
) // tp_size
|
||||||
|
|
||||||
self.key_cache: list[torch.Tensor] = []
|
self.key_cache: list[torch.Tensor] = []
|
||||||
self.value_cache: list[torch.Tensor] = []
|
self.value_cache: list[torch.Tensor] = []
|
||||||
|
Loading…
Reference in New Issue
Block a user