[causal mask] fix preparation with multi-gpu (#37612)

* fix multi-gpu

* forgot non-copied models

* fixup
Raushan Turganbay 2025-04-25 09:34:18 +02:00 committed by GitHub
parent 7bb619d710
commit 79d4bc761d
67 changed files with 278 additions and 481 deletions
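
The change repeated across every file below follows one pattern: `_prepare_4d_causal_attention_mask_with_cache_position` no longer takes a `device` argument, and the mask is allocated on `cache_position.device` instead. On a model sharded across several GPUs, `cache_position` already lives on the device of the decoder layers that consume the mask, so deriving the device from it keeps the mask and the hidden states on the same GPU without threading an extra argument through `GenerationMixin`. Below is a minimal, self-contained sketch of the new pattern; the function name, the example values, and the simplified padding handling are illustrative, not the exact library code.

import torch
from typing import Optional


def build_causal_mask(
    attention_mask: Optional[torch.Tensor],
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    """Sketch of the post-#37612 pattern: no `device` parameter, the mask follows cache_position."""
    min_dtype = torch.finfo(dtype).min
    # Allocate directly on cache_position.device so the mask matches the shard that uses it.
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Positions beyond the current cache position stay masked; already-cached positions become visible.
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for in-place edit
        mask_length = attention_mask.shape[-1]
        # Simplified padding handling (assumption: 2D attention_mask of 0s and 1s).
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask


# Illustrative call: cache_position carries the target device ("cuda:1" on a sharded model,
# "cpu" here so the sketch runs anywhere), so no separate device argument is needed.
cache_position = torch.arange(3, device="cpu")
mask = build_causal_mask(
    attention_mask=None,
    sequence_length=3,
    target_length=8,
    dtype=torch.float32,
    cache_position=cache_position,
    batch_size=2,
)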

View File

@ -557,10 +557,8 @@ class GenerationMixin:
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs[input_ids_key].shape
device = model_inputs[input_ids_key].device
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@ -586,7 +584,6 @@ class GenerationMixin:
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
config=self.config,

View File

@ -1003,7 +1003,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1020,7 +1020,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1045,7 +1044,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1065,8 +1063,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1078,11 +1074,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1313,7 +1313,7 @@ class BambaModel(BambaPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
target_length = (
attention_mask.shape[-1]
@ -1327,7 +1327,6 @@ class BambaModel(BambaPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1352,7 +1351,6 @@ class BambaModel(BambaPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1372,8 +1370,6 @@ class BambaModel(BambaPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1385,11 +1381,11 @@ class BambaModel(BambaPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1081,7 +1081,7 @@ class BambaModel(BambaPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
target_length = (
attention_mask.shape[-1]
@ -1095,7 +1095,6 @@ class BambaModel(BambaPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1120,7 +1119,6 @@ class BambaModel(BambaPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1140,8 +1138,6 @@ class BambaModel(BambaPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1153,11 +1149,11 @@ class BambaModel(BambaPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -773,7 +773,7 @@ class BloomModel(BloomPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -790,7 +790,6 @@ class BloomModel(BloomPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -816,7 +815,6 @@ class BloomModel(BloomPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -836,8 +834,6 @@ class BloomModel(BloomPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -849,11 +845,11 @@ class BloomModel(BloomPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1406,7 +1406,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1423,7 +1423,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1449,7 +1448,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1469,8 +1467,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1482,11 +1478,11 @@ class ChameleonModel(ChameleonPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -619,7 +619,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -636,7 +636,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -662,7 +661,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -682,8 +680,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -695,11 +691,11 @@ class CodeGenModel(CodeGenPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -652,7 +652,7 @@ class CohereModel(CoherePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -669,7 +669,6 @@ class CohereModel(CoherePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -694,7 +693,6 @@ class CohereModel(CoherePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -714,8 +712,6 @@ class CohereModel(CoherePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -727,11 +723,11 @@ class CohereModel(CoherePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -686,7 +686,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -706,8 +705,6 @@ class Cohere2Model(Cohere2PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -719,11 +716,11 @@ class Cohere2Model(Cohere2PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1144,7 +1144,7 @@ class DbrxModel(DbrxPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1161,7 +1161,6 @@ class DbrxModel(DbrxPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1187,7 +1186,6 @@ class DbrxModel(DbrxPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1207,8 +1205,6 @@ class DbrxModel(DbrxPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1220,11 +1216,11 @@ class DbrxModel(DbrxPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -797,7 +797,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -814,7 +814,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -839,7 +838,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -859,8 +857,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -872,11 +868,11 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -900,7 +900,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -917,7 +917,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -942,7 +941,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -962,8 +960,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -975,11 +971,11 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1486,7 +1486,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1503,7 +1503,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1528,7 +1527,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1548,8 +1546,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1561,11 +1557,11 @@ class Emu3TextModel(Emu3PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1076,7 +1076,6 @@ class FalconModel(FalconPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1096,8 +1095,6 @@ class FalconModel(FalconPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1109,11 +1106,11 @@ class FalconModel(FalconPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -620,7 +620,7 @@ class GemmaModel(GemmaPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -637,7 +637,6 @@ class GemmaModel(GemmaPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -662,7 +661,6 @@ class GemmaModel(GemmaPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -682,8 +680,6 @@ class GemmaModel(GemmaPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -695,11 +691,11 @@ class GemmaModel(GemmaPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -711,7 +711,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -731,8 +730,6 @@ class Gemma2Model(Gemma2PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -744,11 +741,11 @@ class Gemma2Model(Gemma2PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -796,7 +796,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -816,8 +815,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -829,11 +826,11 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -635,7 +635,7 @@ class GlmModel(GlmPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -652,7 +652,6 @@ class GlmModel(GlmPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -677,7 +676,6 @@ class GlmModel(GlmPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -697,8 +695,6 @@ class GlmModel(GlmPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -710,11 +706,11 @@ class GlmModel(GlmPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -643,7 +643,7 @@ class Glm4Model(Glm4PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -660,7 +660,6 @@ class Glm4Model(Glm4PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -685,7 +684,6 @@ class Glm4Model(Glm4PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -705,8 +703,6 @@ class Glm4Model(Glm4PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -718,11 +714,11 @@ class Glm4Model(Glm4PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -820,7 +820,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -837,7 +837,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -863,7 +862,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -883,8 +881,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -896,11 +892,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -634,7 +634,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -651,7 +651,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -676,7 +675,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -696,8 +694,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -709,11 +705,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -671,7 +671,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -688,7 +688,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -714,7 +713,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -734,8 +732,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -747,11 +743,11 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -921,7 +921,7 @@ class GPTJModel(GPTJPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -938,7 +938,6 @@ class GPTJModel(GPTJPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -964,7 +963,6 @@ class GPTJModel(GPTJPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -984,8 +982,6 @@ class GPTJModel(GPTJPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -997,11 +993,11 @@ class GPTJModel(GPTJPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -635,7 +635,7 @@ class GraniteModel(GranitePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -652,7 +652,6 @@ class GraniteModel(GranitePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -677,7 +676,6 @@ class GraniteModel(GranitePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -697,8 +695,6 @@ class GraniteModel(GranitePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -710,11 +706,11 @@ class GraniteModel(GranitePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1120,7 +1120,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1137,7 +1137,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1163,7 +1162,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1183,8 +1181,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1196,11 +1192,11 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1065,7 +1065,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1082,7 +1082,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1107,7 +1106,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1127,8 +1125,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1140,11 +1136,11 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -620,7 +620,7 @@ class HeliumModel(HeliumPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -637,7 +637,6 @@ class HeliumModel(HeliumPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -662,7 +661,6 @@ class HeliumModel(HeliumPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -682,8 +680,6 @@ class HeliumModel(HeliumPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -695,11 +691,11 @@ class HeliumModel(HeliumPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1416,7 +1416,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1433,7 +1433,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1459,7 +1458,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1479,8 +1477,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1492,11 +1488,11 @@ class IdeficsModel(IdeficsPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1123,7 +1123,7 @@ class JetMoeModel(JetMoePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1140,7 +1140,6 @@ class JetMoeModel(JetMoePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1166,7 +1165,6 @@ class JetMoeModel(JetMoePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1186,8 +1184,6 @@ class JetMoeModel(JetMoePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1199,11 +1195,11 @@ class JetMoeModel(JetMoePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -625,7 +625,7 @@ class LlamaModel(LlamaPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -642,7 +642,6 @@ class LlamaModel(LlamaPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -667,7 +666,6 @@ class LlamaModel(LlamaPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -687,8 +685,6 @@ class LlamaModel(LlamaPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -700,11 +696,11 @@ class LlamaModel(LlamaPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -785,7 +785,6 @@ class Llama4TextModel(Llama4PreTrainedModel):
sequence_length=sequence_length,
target_length=max(full_cache_length, attention_chunk_size),
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1846,7 +1845,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1866,8 +1864,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1879,11 +1875,11 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1631,7 +1631,7 @@ class LongT5Stack(LongT5PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1648,7 +1648,6 @@ class LongT5Stack(LongT5PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1674,7 +1673,6 @@ class LongT5Stack(LongT5PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1694,8 +1692,6 @@ class LongT5Stack(LongT5PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1707,11 +1703,11 @@ class LongT5Stack(LongT5PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -1086,7 +1086,7 @@ class MimiTransformerModel(nn.Module):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1106,7 +1106,6 @@ class MimiTransformerModel(nn.Module):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1133,7 +1132,6 @@ class MimiTransformerModel(nn.Module):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MimiConfig,
@ -1152,8 +1150,6 @@ class MimiTransformerModel(nn.Module):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1169,14 +1165,16 @@ class MimiTransformerModel(nn.Module):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
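
The sliding-window variant above combines two masks, both built on cache_position.device: keys in the future of a query are masked, and so are keys that have slid out of the window behind it. A minimal sketch of that logic, assuming sliding_window plays the role of config.get_text_config().sliding_window and the other names mirror the hunk:

import torch


def sliding_window_causal_mask(
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    sliding_window: int,
) -> torch.Tensor:
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    # A key is masked if it is in the future of the query ...
    diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    # ... or if it has already slid out of the window behind the query.
    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
        cache_position.reshape(-1, 1) - sliding_window
    )
    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
    causal_mask *= diagonal_attend_mask
    return causal_mask[None, None, :, :]


mask = sliding_window_causal_mask(4, 8, torch.float32, torch.arange(4), sliding_window=3)
print((mask[0, 0] == 0).int())  # 1s mark the keys each query may attend to

With a window of 3 over 8 key slots, query position 3 can attend keys 1 through 3 only.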


@ -609,7 +609,7 @@ class MistralModel(MistralPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -629,7 +629,6 @@ class MistralModel(MistralPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -655,7 +654,6 @@ class MistralModel(MistralPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MistralConfig,
@ -674,8 +672,6 @@ class MistralModel(MistralPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -691,14 +687,16 @@ class MistralModel(MistralPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -171,7 +171,7 @@ class MistralModel(LlamaModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -191,7 +191,6 @@ class MistralModel(LlamaModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -217,7 +216,6 @@ class MistralModel(LlamaModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MistralConfig,
@ -236,8 +234,6 @@ class MistralModel(LlamaModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -253,14 +249,16 @@ class MistralModel(LlamaModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -751,7 +751,7 @@ class MixtralModel(MixtralPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -771,7 +771,6 @@ class MixtralModel(MixtralPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -797,7 +796,6 @@ class MixtralModel(MixtralPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MixtralConfig,
@ -816,8 +814,6 @@ class MixtralModel(MixtralPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -833,14 +829,16 @@ class MixtralModel(MixtralPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1092,7 +1092,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1109,7 +1109,6 @@ class MllamaPreTrainedModel(PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1135,7 +1134,6 @@ class MllamaPreTrainedModel(PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1155,8 +1153,6 @@ class MllamaPreTrainedModel(PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1168,11 +1164,11 @@ class MllamaPreTrainedModel(PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -968,7 +968,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -985,7 +985,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1010,7 +1009,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1030,8 +1028,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1043,11 +1039,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -1312,7 +1312,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1332,7 +1332,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1359,7 +1358,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MoshiDepthConfig,
@ -1378,8 +1376,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1395,14 +1391,16 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
@ -1630,7 +1628,7 @@ class MoshiModel(MoshiPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1650,7 +1648,6 @@ class MoshiModel(MoshiPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1677,7 +1674,6 @@ class MoshiModel(MoshiPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: MoshiConfig,
@ -1696,8 +1692,6 @@ class MoshiModel(MoshiPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1713,14 +1707,16 @@ class MoshiModel(MoshiPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1222,7 +1222,7 @@ class MT5Stack(MT5PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1239,7 +1239,6 @@ class MT5Stack(MT5PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1265,7 +1264,6 @@ class MT5Stack(MT5PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1285,8 +1283,6 @@ class MT5Stack(MT5PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1298,11 +1294,11 @@ class MT5Stack(MT5PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -882,7 +882,7 @@ class NemotronModel(NemotronPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -899,7 +899,6 @@ class NemotronModel(NemotronPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -925,7 +924,6 @@ class NemotronModel(NemotronPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -945,8 +943,6 @@ class NemotronModel(NemotronPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -958,11 +954,11 @@ class NemotronModel(NemotronPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -595,7 +595,7 @@ class OlmoModel(OlmoPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -612,7 +612,6 @@ class OlmoModel(OlmoPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -637,7 +636,6 @@ class OlmoModel(OlmoPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -657,8 +655,6 @@ class OlmoModel(OlmoPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -670,11 +666,11 @@ class OlmoModel(OlmoPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -601,7 +601,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -618,7 +618,6 @@ class Olmo2Model(Olmo2PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -643,7 +642,6 @@ class Olmo2Model(Olmo2PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -663,8 +661,6 @@ class Olmo2Model(Olmo2PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -676,11 +672,11 @@ class Olmo2Model(Olmo2PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -672,7 +672,7 @@ class OPTDecoder(OPTPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -689,7 +689,6 @@ class OPTDecoder(OPTPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -715,7 +714,6 @@ class OPTDecoder(OPTPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -735,8 +733,6 @@ class OPTDecoder(OPTPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -748,11 +744,11 @@ class OPTDecoder(OPTPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -49,7 +49,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
@ -70,8 +69,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
@ -85,7 +82,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
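
This helper differs from the others in two ways visible in the hunk: min_dtype is passed in rather than derived from dtype, and a mask that already arrives as an inverted 4-D tensor is used as-is, with the causal diagonal applied only during training. A rough sketch of the two branches that are shown; what follows the is_training check is cut off above, so the sketch stops at the on-device allocation (the function name is illustrative):

import torch


def prefix_style_causal_mask(attention_mask, sequence_length, target_length, dtype, min_dtype, cache_position):
    if attention_mask is not None and attention_mask.dim() == 4:
        # The mask already arrives in inverted 4-D form and is used as-is.
        return attention_mask
    # Otherwise the fully masked canvas is allocated directly on cache_position.device,
    # which is what dropping the separate `device` argument buys on sharded models.
    return torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )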


@ -682,7 +682,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -699,7 +699,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -725,7 +724,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -745,8 +743,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -758,11 +754,11 @@ class PersimmonModel(PersimmonPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -606,7 +606,7 @@ class PhiModel(PhiPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -623,7 +623,6 @@ class PhiModel(PhiPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -648,7 +647,6 @@ class PhiModel(PhiPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -668,8 +666,6 @@ class PhiModel(PhiPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -681,11 +677,11 @@ class PhiModel(PhiPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -664,7 +664,7 @@ class Phi3Model(Phi3PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -684,7 +684,6 @@ class Phi3Model(Phi3PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -710,7 +709,6 @@ class Phi3Model(Phi3PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Phi3Config,
@ -729,8 +727,6 @@ class Phi3Model(Phi3PreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -746,14 +742,16 @@ class Phi3Model(Phi3PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1949,7 +1949,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1969,7 +1969,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1995,7 +1994,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Phi4MultimodalConfig,
@ -2014,8 +2012,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -2031,14 +2027,16 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1221,7 +1221,7 @@ class PhimoeModel(PhimoePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1241,7 +1241,6 @@ class PhimoeModel(PhimoePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1268,7 +1267,6 @@ class PhimoeModel(PhimoePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: PhimoeConfig,
@ -1287,8 +1285,6 @@ class PhimoeModel(PhimoePreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1304,14 +1300,16 @@ class PhimoeModel(PhimoePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1618,7 +1618,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1635,7 +1635,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1661,7 +1660,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1681,8 +1679,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1694,11 +1690,11 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@ -1031,7 +1031,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1048,7 +1048,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1074,7 +1073,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1094,8 +1092,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1107,11 +1103,11 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
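
Every variant above allocates the mask on cache_position.device, which is what matters once a model is sharded across GPUs with device_map: hidden states, cache, and mask no longer automatically share a device. A hedged end-to-end example of the setup this commit targets, with static-cache generation on a sharded model (the checkpoint name and generation arguments are illustrative, not taken from this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B"  # illustrative checkpoint; any supported decoder-only model behaves similarly
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # shards layers across the visible GPUs
)

inputs = tokenizer("The causal mask should follow cache_position", return_tensors="pt").to(model.device)
# cache_implementation="static" routes generation through the fixed-shape mask preparation
# touched by this commit; before the fix this path could hit a device mismatch on sharded models.
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))

On a single-GPU or CPU-only machine the same snippet still runs; the mismatch it guards against only appears when device_map actually splits the layers.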


@ -622,7 +622,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -642,7 +642,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -668,7 +667,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2Config,
@ -687,8 +685,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -704,14 +700,16 @@ class Qwen2Model(Qwen2PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -2088,7 +2088,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -2108,7 +2108,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -2134,7 +2133,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2_5OmniConfig,
@ -2153,8 +2151,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -2170,14 +2166,16 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
@ -2806,7 +2804,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -2826,7 +2824,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -2852,7 +2849,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2_5OmniConfig,
@ -2871,8 +2867,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -2888,14 +2882,16 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1266,7 +1266,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1286,7 +1286,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1312,7 +1311,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2_5_VLConfig,
@ -1331,8 +1329,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1348,14 +1344,16 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
# the check is needed to verify is current checkpoint was trained with sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)


@ -1082,7 +1082,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1102,7 +1102,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1129,7 +1128,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2MoeConfig,
@ -1148,8 +1146,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1165,14 +1161,16 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

View File

@ -1218,7 +1218,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -1238,7 +1238,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -1265,7 +1264,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen2VLConfig,
@ -1284,8 +1282,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1301,14 +1297,16 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
@ -1692,9 +1690,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only

View File

@ -649,7 +649,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -669,7 +669,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -695,7 +694,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen3Config,
@ -714,8 +712,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -731,14 +727,16 @@ class Qwen3Model(Qwen3PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

View File

@ -765,7 +765,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -785,7 +785,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -811,7 +810,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Qwen3MoeConfig,
@ -830,8 +828,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -847,14 +843,16 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

View File

@ -936,7 +936,7 @@ class StableLmModel(StableLmPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -953,7 +953,6 @@ class StableLmModel(StableLmPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -979,7 +978,6 @@ class StableLmModel(StableLmPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -999,8 +997,6 @@ class StableLmModel(StableLmPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1012,11 +1008,11 @@ class StableLmModel(StableLmPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
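To make the non-sliding variant concrete, here is a tiny worked example with hypothetical sizes (three new tokens appended after a two-token prefix, static cache of length five); it only reuses the lines shown in the hunk above and runs on CPU.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min  # a large finite negative, so multiplying by 0 gives 0 rather than the NaN that -inf * 0 would
sequence_length, target_length = 3, 5
cache_position = torch.arange(2, 5)  # absolute positions 2, 3, 4 of the new tokens

causal_mask = torch.full(
    (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
print(causal_mask)
# Row i attends (value 0) to every cache slot <= cache_position[i]; later slots keep the large negative fill.

Row 0 masks slots 3 and 4, row 1 masks slot 4, and row 2 attends to the whole cache, exactly the causal pattern expected for positions 2, 3 and 4.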

View File

@ -613,7 +613,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
@ -633,7 +633,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
@ -659,7 +658,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Starcoder2Config,
@ -678,8 +676,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -695,14 +691,16 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-1, 1
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)

View File

@ -1165,7 +1165,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1182,7 +1182,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1208,7 +1207,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1228,8 +1226,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1241,11 +1237,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1236,7 +1236,7 @@ class T5Stack(T5PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1253,7 +1253,6 @@ class T5Stack(T5PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1279,7 +1278,6 @@ class T5Stack(T5PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1299,8 +1297,6 @@ class T5Stack(T5PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1312,11 +1308,11 @@ class T5Stack(T5PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1568,7 +1568,7 @@ class UdopStack(UdopPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1585,7 +1585,6 @@ class UdopStack(UdopPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1611,7 +1610,6 @@ class UdopStack(UdopPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1631,8 +1629,6 @@ class UdopStack(UdopPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1644,11 +1640,11 @@ class UdopStack(UdopPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -879,7 +879,7 @@ class UMT5Stack(UMT5PreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -896,7 +896,6 @@ class UMT5Stack(UMT5PreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -922,7 +921,6 @@ class UMT5Stack(UMT5PreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -942,8 +940,6 @@ class UMT5Stack(UMT5PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -955,11 +951,11 @@ class UMT5Stack(UMT5PreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -1406,7 +1406,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1423,7 +1423,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@ -1449,7 +1448,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@ -1469,8 +1467,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
@ -1482,11 +1478,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit

View File

@ -4018,6 +4018,34 @@ class GenerationIntegrationTests(unittest.TestCase):
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_multi_gpu_causal_mask(self):
"""
Tests that the cache position device doesn't clash with the causal mask device when using multiple GPUs.
In practice this happens only when the multimodal encoder is large, so `embed_tokens` gets allocated to the next device.
The error is triggered whenever a batched input is used, so that `causal_mask` is actually prepared instead of
being `None`.
"""
# need to split manually, as `auto` device mapping doesn't work well with an unbalanced model
device_map = {
"visual": 0,
"model.embed_tokens": 1,
"model.layers.0": 1,
"model.layers.1": 1,
"model.rotary_emb": 1,
"model.norm.weight": 1,
"lm_head": 1,
}
model = AutoModelForImageTextToText.from_pretrained(
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", device_map=device_map
)
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
text = ["Hello world", "Today I went to the supermarket to buy"]
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
_ = model.generate(**inputs, max_new_tokens=20)
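For context, a hypothetical two-GPU reproduction of the failure mode this test guards against (not part of the test suite): with the old code the `arange` lived on the device taken from the input tensor, and comparing it against a `cache_position` tensor that had ended up on another GPU raised a device-mismatch `RuntimeError`. Deriving everything from `cache_position.device` makes the clash impossible by construction.

import torch

if torch.cuda.device_count() >= 2:  # hypothetical repro, needs two GPUs
    target_length = 8
    cache_position = torch.arange(4, device="cuda:1")  # follows the decoder layers
    device = torch.device("cuda:0")                    # old code: taken from the input tensor / input ids

    try:
        # Old behaviour: tensors on different devices cannot be compared.
        torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    except RuntimeError as err:
        print(f"device clash, as before the fix: {err}")

    # New behaviour: everything follows cache_position.device, so there is no clash.
    mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    print(mask.shape, mask.device)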
@pytest.mark.generate
@require_torch_multi_gpu
def test_init_static_cache_multi_gpu(self):