mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00

[Cache] Don't initialize the cache on meta device (#36543)

This commit is contained in:
parent 79254c9b61
commit c4161238bd
@@ -10,7 +10,6 @@ from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg


if is_hqq_available():
@@ -1064,18 +1063,19 @@ class StaticCache(Cache):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
number of beams if you are running beam search
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.


Example:
@@ -1101,7 +1101,6 @@ class StaticCache(Cache):
is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1128,7 +1127,6 @@ class StaticCache(Cache):
)

self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1139,11 +1137,12 @@ class StaticCache(Cache):
self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
device = torch.device(device) if device is not None else None
for idx in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = self.device
layer_device = device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
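The hunk above switches `StaticCache.__init__` from a `meta`-device placeholder to allocating each layer's key/value buffer on a concrete device: the global `device` argument, or the per-layer entry from `layer_device_map` when the model is split across devices. A minimal standalone sketch of that allocation pattern (class and argument names are illustrative, not the Transformers API):

```python
# Illustrative sketch of per-layer KV allocation with an optional layer_device_map.
# `SimpleStaticKV` is a made-up name; shapes only loosely mirror the hunk above.
from typing import Dict, List, Optional, Union

import torch


class SimpleStaticKV:
    def __init__(
        self,
        num_layers: int,
        cache_shape: tuple,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ):
        # Resolve the default device once; fall back to the framework default (CPU) if none given.
        device = torch.device(device) if device is not None else None
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        for idx in range(num_layers):
            # A per-layer override wins over the global device, mirroring the diff's logic.
            layer_device = layer_device_map[idx] if layer_device_map is not None else device
            self.key_cache.append(torch.zeros(cache_shape, dtype=dtype, device=layer_device))
            self.value_cache.append(torch.zeros(cache_shape, dtype=dtype, device=layer_device))


# Example: two layers pinned to explicit devices (both CPU here so it runs anywhere).
kv = SimpleStaticKV(num_layers=2, cache_shape=(1, 4, 16, 8), layer_device_map={0: "cpu", 1: "cpu"})
print([t.device for t in kv.key_cache])
```

Allocating on the real device up front is what lets later `update()` calls stay purely in-place, which is what `torch.compile` and CUDA graphs need.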
@@ -1178,12 +1177,7 @@ class StaticCache(Cache):
Return:
A tuple containing the updated key and value states.
"""

cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
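With the buffers already materialized, `update()` no longer needs the "materialize from meta on first use" branch removed above; it only writes the incoming key/value slice at `cache_position`. A rough illustration of that slot write (not `StaticCache.update` itself):

```python
# Minimal sketch of a static-cache slot write at `cache_position`.
import torch

k_cache = torch.zeros(1, 2, 8, 4)          # (batch, kv_heads, max_cache_len, head_dim)
v_cache = torch.zeros(1, 2, 8, 4)

new_k = torch.randn(1, 2, 3, 4)            # 3 incoming positions
new_v = torch.randn(1, 2, 3, 4)
cache_position = torch.arange(3)           # write slots 0..2

# In-place writes keep the tensors' storage (and data pointer) stable, which is
# what lets torch.compile / CUDA graphs treat them as static addresses.
k_cache.index_copy_(2, cache_position, new_k)
v_cache.index_copy_(2, cache_position, new_v)

print(k_cache[0, 0, :4, 0])                # first three slots now hold the new keys
```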
@@ -1211,8 +1205,6 @@ class StaticCache(Cache):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def get_max_cache_shape(self) -> Optional[int]:
@@ -1221,10 +1213,9 @@ class StaticCache(Cache):
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
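`reset()` drops the `device.type != "meta"` guard because every buffer is now a real tensor, and it keeps zeroing in place so the data pointers tagged via `mark_static_address` stay valid. A quick demonstration in plain PyTorch that in-place `zero_()` preserves the storage address while reallocation does not:

```python
import torch

cache = torch.ones(4, 4)
addr_before = cache.data_ptr()

cache.zero_()                 # in-place: same storage, same address
assert cache.data_ptr() == addr_before

cache = torch.zeros(4, 4)     # reassignment: a brand-new tensor and storage
assert cache.data_ptr() != addr_before
print("in-place reset keeps the static address; reallocation does not")
```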
@@ -1261,14 +1252,14 @@ class SlidingWindowCache(StaticCache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.

Example:

@@ -1329,11 +1320,6 @@ class SlidingWindowCache(StaticCache):
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")

if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
@@ -1380,10 +1366,9 @@ class SlidingWindowCache(StaticCache):

def reset(self):
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()


class EncoderDecoderCache(Cache):
@@ -1573,14 +1558,14 @@ class HybridCache(Cache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.

Example:

@@ -1607,7 +1592,6 @@ class HybridCache(Cache):
# is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -1642,7 +1626,6 @@ class HybridCache(Cache):
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1656,11 +1639,12 @@ class HybridCache(Cache):
min(config.sliding_window, max_cache_len),
self.head_dim,
)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = self.device
layer_device = device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
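`HybridCache` mixes full-length ("global") layers with sliding-window layers and picks a different buffer shape per layer; with this change those buffers are also allocated on a concrete device. A toy sketch of the shape selection (all sizes below are invented for illustration):

```python
import torch

num_layers = 6
sliding_window_pattern = 2            # every 2nd layer is global, mirroring the "layer_switch" idea
max_cache_len = 32
sliding_window = 8
batch, kv_heads, head_dim = 1, 4, 16

is_sliding = [bool((i + 1) % sliding_window_pattern) for i in range(num_layers)]
global_shape = (batch, kv_heads, max_cache_len, head_dim)
sliding_shape = (batch, kv_heads, min(sliding_window, max_cache_len), head_dim)

key_cache = [
    torch.zeros(sliding_shape if is_sliding[i] else global_shape)
    for i in range(num_layers)
]
print([tuple(k.shape) for k in key_cache])
```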
@@ -1717,9 +1701,12 @@ class HybridCache(Cache):
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")

if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
if self.key_cache[layer_idx].device != key_states.device:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
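For sliding layers, `update()` receives a `sliding_window` hint and keeps only the most recent window of keys/values. The library does this with index arithmetic; below is a deliberately simplified ring-buffer-style illustration of the same idea, not the actual `HybridCache` code (the real implementation avoids reallocating so the buffer address stays static):

```python
import torch

window = 4
k_cache = torch.zeros(1, 1, window, 2)     # (batch, heads, window, head_dim)

def sliding_update(k_cache: torch.Tensor, new_k: torch.Tensor, seen_tokens: int) -> torch.Tensor:
    """Append one position; once the window is full, drop the oldest slot."""
    if seen_tokens < window:
        k_cache[:, :, seen_tokens] = new_k
    else:
        k_cache = torch.roll(k_cache, shifts=-1, dims=2)   # shift left
        k_cache[:, :, -1] = new_k                          # newest goes last
    return k_cache

for t in range(6):
    k_cache = sliding_update(k_cache, torch.full((1, 1, 2), float(t)), t)
print(k_cache[0, 0, :, 0])   # tensor([2., 3., 4., 5.]) -> only the last `window` tokens remain
```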
@@ -1753,18 +1740,14 @@ class HybridCache(Cache):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)

if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
@@ -1789,24 +1772,6 @@ class MambaCache:
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta`
device by default, and then moved to input device when updating.

Attributes:
dtype: (`torch.dtype`):
The default `dtype` used to initializing the cache.
device (`torch.device`):
The default device on which the cache was initialized.
intermediate_size: (`int`):
Model's intermediate_size taken from config.
ssm_state_size: (`int`):
Model's state_size taken from config.
conv_kernel_size: (`int`):
Model's convolution kernel size taken from config
conv_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
ssm_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states

Example:

@@ -1829,6 +1794,7 @@ class MambaCache:
is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
@@ -1847,23 +1813,23 @@ class MambaCache:
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")

self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=self.device,
device=device,
dtype=dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=self.device,
device=device,
dtype=dtype,
)

@@ -1875,11 +1841,10 @@ class MambaCache:
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
if self.conv_states[layer_idx].device.type == "meta":
self.conv_states[layer_idx] = torch.zeros_like(
self.conv_states[layer_idx],
device=new_conv_state.device,
)
# This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. Mamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)

conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
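`MambaCache.update_conv_state` maintains a fixed-size convolution window per layer; the hunk swaps the `meta` materialization for a plain cross-device move. A simplified decode-step illustration of the rolling window, using an in-place `copy_` so the original storage is preserved (not `MambaCache`'s exact code):

```python
import torch

conv_kernel_size = 4
conv_state = torch.zeros(1, 8, conv_kernel_size)   # (batch, intermediate_size, conv_kernel_size)
original_ptr = conv_state.data_ptr()

def append_column(conv_state: torch.Tensor, new_column: torch.Tensor) -> torch.Tensor:
    """Shift the window left by one slot and write the newest column last, in place."""
    rolled = conv_state.roll(shifts=-1, dims=-1)
    rolled[:, :, -1] = new_column
    conv_state.copy_(rolled)      # copy back in place so the storage (static address) is kept
    return conv_state

for step in range(6):
    append_column(conv_state, torch.full((1, 8), float(step)))

print(conv_state[0, 0])           # tensor([2., 3., 4., 5.]): only the last 4 columns remain
assert conv_state.data_ptr() == original_ptr
```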
@@ -1896,10 +1861,9 @@ class MambaCache:

def reset(self):
for layer_idx in range(len(self.conv_states)):
if self.conv_states[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()

@property
def batch_size(self):
@@ -1924,33 +1888,16 @@ class OffloadedStaticCache(StaticCache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between different gpus.
You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.

Attributes:
key_cache (`List[torch.Tensor]`):
Off-loaded key cache tensors. First one will be on device, where-as the others are
off-loaded.
value_cache (`List[torch.Tensor]`):
Off-loaded value cache tensors. First one will be on device, where-as the others are
off-loaded.
max_batch_size (`int`):
The maximum batch size with which this cache can be used.
max_cache_len (`int`):
The maximum sequence length with which this cache can be used.
device (`torch.device`):
The device on which the cache is used.
offload_device (`torch.device`):
The device used to offload to.
dtype (`torch.dtype`):
The `dtype` used to initializing the cache.
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.

Example:

@@ -1973,7 +1920,6 @@ class OffloadedStaticCache(StaticCache):

is_compileable = True

@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@@ -483,7 +483,7 @@ class GenerationConfig(PushToHubMixin):
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)

# Performances
# Performance
self.compile_config = kwargs.pop("compile_config", CompileConfig())
self.disable_compile = kwargs.pop("disable_compile", False)
# Wild card
@@ -1618,6 +1618,40 @@ class GenerationMixin:
model_kwargs["cache_position"] = cache_position
return model_kwargs

def _get_layer_device_map_for_cache_init(self):
"""
Taken from `dispatch_model` from accelerate.
This is needed here if we don't want to make changes in accelerate in order to save execution_device
For offloaded case, we need to get the execution device, not just the device where it is offloaded
"""
execution_device_map = None

if hasattr(self, "hf_device_map"):
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}

num_hidden_layers = self.config.get_text_config().num_hidden_layers
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map

def _get_cache(
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
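The new `_get_layer_device_map_for_cache_init` helper resolves one execution device per decoder layer from `hf_device_map`, so the cache can be pre-allocated on the right device for each layer. A small standalone rendering of its `.{idx}.` matching rule on a hypothetical device map (values and layer count are invented for the example):

```python
def layer_device_map_from(device_map: dict, num_hidden_layers: int) -> dict:
    """Match entries like 'model.layers.3' to layer index 3 using the same '.{idx}.' rule as the diff."""
    layer_device_map = {}
    for module_name, device in device_map.items():
        for idx in range(num_hidden_layers):
            if f".{idx}." in f"{module_name}.":
                layer_device_map[idx] = device
                break
    return layer_device_map

# Hypothetical device_map in the style produced by accelerate / `device_map="auto"`.
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": 1,
    "model.layers.3": 1,
    "model.norm": 1,
    "lm_head": 0,
}
print(layer_device_map_from(device_map, num_hidden_layers=4))
# {0: 0, 1: 0, 2: 1, 3: 1}
```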
@@ -1664,12 +1698,14 @@ class GenerationMixin:
# models. May cause trobles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype

layer_device_map = self._get_layer_device_map_for_cache_init()
cache_kwargs = {
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"dtype": cache_dtype,
"device": device if cache_implementation == "offloaded_static" else None,
"device": device,
"layer_device_map": layer_device_map,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
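After this hunk, `generate` always forwards the target `device` plus the computed `layer_device_map` to the cache class, so a statically shaped cache is created on real devices from the start rather than on `meta`. A usage sketch with the tiny test checkpoint that appears later in this diff:

```python
# Usage sketch: request a static cache during generation; the cache is now created
# directly on the model's device(s) instead of on the `meta` device.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"  # example checkpoint from the tests below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=5, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```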
@@ -597,11 +597,13 @@ class Cohere2Model(Cohere2PreTrainedModel):

if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)

if cache_position is None:
@@ -488,11 +488,13 @@ class Cohere2Model(Gemma2Model):

if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)

if cache_position is None:
@@ -599,11 +599,13 @@ class Gemma2Model(Gemma2PreTrainedModel):

if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)

if cache_position is None:
@@ -437,11 +437,13 @@ class Gemma2Model(GemmaModel):

if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)

if cache_position is None:
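The NOTE in these model hunks points out that, ideally, the `HybridCache` would be built outside the model so a `layer_device_map` can be supplied. A hedged sketch of that manual construction, reusing the arguments shown in the hunks above (checkpoint and sizes are placeholders; the commented `layer_device_map` only matters when the model is sharded across devices):

```python
# Sketch: constructing a HybridCache manually and passing it to generate(),
# instead of letting the forward pass create one. Values below are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b-it"   # a model that uses HybridCache (also used in the test below)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Today is a beautiful day!", return_tensors="pt")
past_key_values = HybridCache(
    model.config,
    max_batch_size=1,
    max_cache_len=64,
    dtype=model.dtype,
    device=model.device,
    # layer_device_map={...}  # per-layer devices when the model is split across GPUs (hypothetical values)
)
out = model.generate(**inputs, max_new_tokens=5, past_key_values=past_key_values)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```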
@@ -2304,45 +2304,6 @@ class GenerationTesterMixin:
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

@pytest.mark.generate
@is_flaky
def test_assisted_decoding_with_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")

config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True

model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
}
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)

# Setting logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
**generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
)
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)

self._check_similar_generate_outputs(with_all_logits, without_all_logits)

@pytest.mark.generate
def test_inherits_generation_mixin(self):
"""
@@ -20,6 +20,7 @@ from parameterized import parameterized

from transformers import set_seed
from transformers.testing_utils import (
CaptureStderr,
get_gpu_count,
is_torch_available,
require_gptq,
@@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase):
torch.testing.assert_close(
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
)

@require_torch_gpu
def test_static_cache_no_cuda_graph_skips(self):
"""
Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.

(? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
messages are being thrown to stderr?)
"""
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_repo)
inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)

# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
with CaptureStderr() as cap:
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
self.assertEqual(cap.err, "")

@require_torch_multi_gpu
def test_static_cache_multi_gpu(self):
"""Regression test for #35164: static cache with multi-gpu"""

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
num_hidden_layers = 26
for i in range(num_hidden_layers):
device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype="bfloat16",
device_map=device_map,
)
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
_ = model(**inputs)
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")