[Cache] Don't initialize the cache on meta device (#36543)

Joao Gante 2025-03-13 10:13:29 +00:00 committed by GitHub
parent 79254c9b61
commit c4161238bd
9 changed files with 138 additions and 147 deletions

View File

@ -10,7 +10,6 @@ from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg
if is_hqq_available():
@ -1064,18 +1063,19 @@ class StaticCache(Cache):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
number of beams if you are running beam search
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is not to indicate any `device`; in that case the cache will be initialized on the `meta`
device by default, and then moved to the input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different GPUs. You can check which layers are mapped to which device in the
associated device map: `model.hf_device_map`.
Example:
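A hedged sketch of typical usage, pre-allocating the cache and handing it to `generate` (the tiny checkpoint is the one exercised by the tests at the end of this diff; sizes are arbitrary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")

# Pre-allocate the cache on a concrete device; with this change there is no lazy
# meta-device allocation, so the tensors get their final storage here.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=64,
    device=model.device,
    dtype=model.dtype,
)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=8)
```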
@ -1101,7 +1101,6 @@ class StaticCache(Cache):
is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@ -1128,7 +1127,6 @@ class StaticCache(Cache):
)
self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@ -1139,11 +1137,12 @@ class StaticCache(Cache):
self.value_cache: List[torch.Tensor] = []
# Note: There will be a significant perf decrease if switching to use 5D tensors instead.
cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
device = torch.device(device) if device is not None else None
for idx in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = self.device
layer_device = device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
@ -1178,12 +1177,7 @@ class StaticCache(Cache):
Return:
A tuple containing the updated key and value states.
"""
cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
@ -1211,8 +1205,6 @@ class StaticCache(Cache):
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def get_max_cache_shape(self) -> Optional[int]:
@ -1221,10 +1213,9 @@ class StaticCache(Cache):
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
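The in-place `zero_()` above is what keeps each tensor's storage stable across resets, which is what compiled CUDA graphs rely on. A minimal sketch of that invariant (the tiny checkpoint is the one used in the tests below; assumptions otherwise none beyond a config with standard attention attributes):

```python
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=8, device="cpu", dtype=torch.float32)

ptr_before = cache.key_cache[0].data_ptr()
cache.reset()  # zero_() in place, no reallocation
assert cache.key_cache[0].data_ptr() == ptr_before
```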
@property
def batch_size(self):
@ -1261,14 +1252,14 @@ class SlidingWindowCache(StaticCache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is not to indicate any `device`; in that case the cache will be initialized on the `meta`
device by default, and then moved to the input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different GPUs. You can check which layers are mapped to which device in the
associated device map: `model.hf_device_map`.
Example:
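A hedged sketch of requesting this cache through `generate` for a model whose config defines `sliding_window` (the tiny Mistral checkpoint is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("ping", return_tensors="pt")
# `generate` builds the SlidingWindowCache internally (and, on multi-gpu setups,
# fills in `layer_device_map` via `_get_layer_device_map_for_cache_init` further down).
out = model.generate(**inputs, max_new_tokens=4, cache_implementation="sliding_window")
```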
@ -1329,11 +1320,6 @@ class SlidingWindowCache(StaticCache):
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
@ -1380,10 +1366,9 @@ class SlidingWindowCache(StaticCache):
def reset(self):
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
class EncoderDecoderCache(Cache):
@ -1573,14 +1558,14 @@ class HybridCache(Cache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is not to indicate any `device`; in that case the cache will be initialized on the `meta`
device by default, and then moved to the input device when updating.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different GPUs. You can check which layers are mapped to which device in the
associated device map: `model.hf_device_map`.
Example:
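A hedged sketch of requesting this cache through `generate` (same Gemma-2 checkpoint as the multi-gpu test at the end of this diff; access to the gated repo is assumed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="bfloat16", device_map="auto")

inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="hybrid")
```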
@ -1607,7 +1592,6 @@ class HybridCache(Cache):
# is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,
@ -1642,7 +1626,6 @@ class HybridCache(Cache):
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@ -1656,11 +1639,12 @@ class HybridCache(Cache):
min(config.sliding_window, max_cache_len),
self.head_dim,
)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = self.device
layer_device = device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
@ -1717,9 +1701,12 @@ class HybridCache(Cache):
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
if self.key_cache[layer_idx].device.type == "meta":
self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device)
self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device)
# These two `if` blocks are only reached on multi-gpu setups when `layer_device_map` is not passed. They are
# used when the cache is initialized in the forward pass (e.g. Gemma2)
if self.key_cache[layer_idx].device != key_states.device:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
@ -1753,18 +1740,14 @@ class HybridCache(Cache):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
if self.key_cache[layer_idx].device.type == "meta":
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
"""Resets the cache values while preserving the objects"""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
@property
def batch_size(self):
@ -1789,24 +1772,6 @@ class MambaCache:
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
The recommended way, however, is not to indicate any `device`; in that case the cache will be initialized on the `meta`
device by default, and then moved to the input device when updating.
Attributes:
dtype: (`torch.dtype`):
The default `dtype` used to initialize the cache.
device (`torch.device`):
The default device on which the cache was initialized.
intermediate_size: (`int`):
Model's intermediate_size taken from config.
ssm_state_size: (`int`):
Model's state_size taken from config.
conv_kernel_size: (`int`):
Model's convolution kernel size taken from config.
conv_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
ssm_states: (`torch.Tensor`):
A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds SSM states.
Example:
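A hedged sketch of what the cache allocates per layer (the Mamba checkpoint is illustrative; the shapes match the `conv_states`/`ssm_states` buffers built in `__init__` in the next hunk):

```python
import torch
from transformers import AutoConfig, MambaCache

config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")
cache = MambaCache(config=config, max_batch_size=1, device="cpu", dtype=torch.float32)

# One buffer per layer, allocated directly on the requested device (no meta device).
print(cache.conv_states[0].shape)  # (max_batch_size, intermediate_size, conv_kernel_size)
print(cache.ssm_states[0].shape)   # (max_batch_size, intermediate_size, ssm_state_size)
```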
@ -1829,6 +1794,7 @@ class MambaCache:
is_compileable = True
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
@ -1847,23 +1813,23 @@ class MambaCache:
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")
self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
device = torch.device(device) if device is not None else None
for _ in range(config.num_hidden_layers):
conv_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=self.device,
device=device,
dtype=dtype,
)
ssm_state: torch.Tensor = torch.zeros(
self.max_batch_size,
self.intermediate_size,
self.ssm_state_size,
device=self.device,
device=device,
dtype=dtype,
)
@ -1875,11 +1841,10 @@ class MambaCache:
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
if self.conv_states[layer_idx].device.type == "meta":
self.conv_states[layer_idx] = torch.zeros_like(
self.conv_states[layer_idx],
device=new_conv_state.device,
)
# This `if` block is only reached on multi-gpu setups when `layer_device_map` is not passed. It is used
# when the cache is initialized in the forward pass (e.g. Mamba)
if self.conv_states[layer_idx].device != new_conv_state.device:
self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
@ -1896,10 +1861,9 @@ class MambaCache:
def reset(self):
for layer_idx in range(len(self.conv_states)):
if self.conv_states[layer_idx].device.type != "meta":
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
# In-place ops prevent breaking the static address
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
@property
def batch_size(self):
@ -1924,33 +1888,16 @@ class OffloadedStaticCache(StaticCache):
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`Union[str, torch.device]`):
The device on which the cache should be initialized. Should be the same as the
layer device.
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*):
The default `dtype` to use when initializing the cache.
offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
The device to offload to. Defaults to CPU.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device in the associated device map: `model.hf_device_map`.
Attributes:
key_cache (`List[torch.Tensor]`):
Off-loaded key cache tensors. The first one will be on device, whereas the others are
off-loaded.
value_cache (`List[torch.Tensor]`):
Off-loaded value cache tensors. The first one will be on device, whereas the others are
off-loaded.
max_batch_size (`int`):
The maximum batch size with which this cache can be used.
max_cache_len (`int`):
The maximum sequence length with which this cache can be used.
device (`torch.device`):
The device on which the cache is used.
offload_device (`torch.device`):
The device used to offload to.
dtype (`torch.dtype`):
The `dtype` used to initialize the cache.
Mapping between the layers and their devices. This is required when you are manually initializing the cache
and the model is split between different GPUs. You can check which layers are mapped to which device in the
associated device map: `model.hf_device_map`.
Example:
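A hedged sketch of requesting this cache through `generate` (assumes a CUDA device; the tiny checkpoint is the one used in the tests below, and `"offloaded_static"` is the implementation name referenced in `_get_cache` later in this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")

inputs = tokenizer("foo bar", return_tensors="pt").to("cuda")
# "offloaded_static" keeps the active key/value tensors on the GPU and offloads the rest to the CPU.
out = model.generate(**inputs, max_new_tokens=4, cache_implementation="offloaded_static")
```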
@ -1973,7 +1920,6 @@ class OffloadedStaticCache(StaticCache):
is_compileable = True
@deprecate_kwarg("layer_device_map", version="4.52.0")
def __init__(
self,
config: PretrainedConfig,

View File

@ -483,7 +483,7 @@ class GenerationConfig(PushToHubMixin):
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
# Performances
# Performance
self.compile_config = kwargs.pop("compile_config", CompileConfig())
self.disable_compile = kwargs.pop("disable_compile", False)
# Wild card

View File

@ -1618,6 +1618,40 @@ class GenerationMixin:
model_kwargs["cache_position"] = cache_position
return model_kwargs
def _get_layer_device_map_for_cache_init(self):
"""
Taken from `dispatch_model` in accelerate.
This is needed here if we don't want to make changes in accelerate in order to save `execution_device`.
For the offloaded case, we need to get the execution device, not just the device where it is offloaded.
"""
execution_device_map = None
if hasattr(self, "hf_device_map"):
if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
num_hidden_layers = self.config.get_text_config().num_hidden_layers
if execution_device_map is None:
return None
elif len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(num_hidden_layers)}
layer_device_map = {}
for layer in execution_device_map:
for idx in range(num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map
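A self-contained toy version of the mapping built above (the `hf_device_map` below is made up for illustration; real maps come from accelerate's `dispatch_model`):

```python
# Module names are matched to layer indices via the ".{idx}." substring test used above.
hf_device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 1,
    "model.norm": 1,
    "lm_head": 0,
}
num_hidden_layers = 2

layer_device_map = {}
for module_name, device in hf_device_map.items():
    for idx in range(num_hidden_layers):
        if f".{idx}." in f"{module_name}." and idx not in layer_device_map:
            layer_device_map[idx] = device

print(layer_device_map)  # {0: 0, 1: 1}
```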
def _get_cache(
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
@ -1664,12 +1698,14 @@ class GenerationMixin:
# models. May cause troubles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype
layer_device_map = self._get_layer_device_map_for_cache_init()
cache_kwargs = {
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"dtype": cache_dtype,
"device": device if cache_implementation == "offloaded_static" else None,
"device": device,
"layer_device_map": layer_device_map,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:

View File

@ -597,11 +597,13 @@ class Cohere2Model(Cohere2PreTrainedModel):
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
if cache_position is None:
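A hedged sketch of the initialization the NOTE above refers to, building the `HybridCache` outside the model with an explicit `layer_device_map` (checkpoint from the multi-gpu test at the end of this diff; the two-GPU assumption and the 50/50 layer split are illustrative, and mismatched layers would be moved on the fly by the updated `update()` anyway):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b-it"  # same checkpoint as the multi-gpu test below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="bfloat16", device_map="auto")

# Illustrative split across two GPUs; in practice derive this from `model.hf_device_map`.
num_layers = model.config.num_hidden_layers
layer_device_map = {idx: 0 if idx < num_layers // 2 else 1 for idx in range(num_layers)}

past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=256,
    dtype=model.dtype,
    layer_device_map=layer_device_map,
)
inputs = tokenizer("Hello", return_tensors="pt").to(0)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=8)
```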

View File

@ -488,11 +488,13 @@ class Cohere2Model(Gemma2Model):
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
if cache_position is None:

View File

@ -599,11 +599,13 @@ class Gemma2Model(Gemma2PreTrainedModel):
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
if cache_position is None:

View File

@ -437,11 +437,13 @@ class Gemma2Model(GemmaModel):
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
# NOTE: ideally, `HybridCache` should be initialized outside the model with `layer_device_map`
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
max_cache_len=seq_len,
dtype=inputs_embeds.dtype,
device=self.device,
)
if cache_position is None:

View File

@ -2304,45 +2304,6 @@ class GenerationTesterMixin:
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
@is_flaky
def test_assisted_decoding_with_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
}
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
# Setting logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
**generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
)
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
@pytest.mark.generate
def test_inherits_generation_mixin(self):
"""

View File

@ -20,6 +20,7 @@ from parameterized import parameterized
from transformers import set_seed
from transformers.testing_utils import (
CaptureStderr,
get_gpu_count,
is_torch_available,
require_gptq,
@ -654,3 +655,42 @@ class CacheIntegrationTest(unittest.TestCase):
torch.testing.assert_close(
actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
)
@require_torch_gpu
def test_static_cache_no_cuda_graph_skips(self):
"""
Tests that generating with a static cache and compilation doesn't skip cuda graphs. Regression test for #36543.
(Oddly, we set `fullgraph=True`, which according to the torch docs means it should raise an exception; instead,
messages are being thrown to stderr.)
"""
model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_repo)
inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
with CaptureStderr() as cap:
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
self.assertEqual(cap.err, "")
@require_torch_multi_gpu
def test_static_cache_multi_gpu(self):
"""Regression test for #35164: static cache with multi-gpu"""
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
num_hidden_layers = 26
for i in range(num_hidden_layers):
device_map[f"model.layers.{i}"] = 0 if i < 13 else 1
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype="bfloat16",
device_map=device_map,
)
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
_ = model(**inputs)
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")