diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2ae16408b8d..c0f6da8d768 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -557,10 +557,8 @@ class GenerationMixin:
         if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
             if model_inputs["inputs_embeds"] is not None:
                 batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-                device = model_inputs["inputs_embeds"].device
             else:
                 batch_size, sequence_length = model_inputs[input_ids_key].shape
-                device = model_inputs[input_ids_key].device

             # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
             # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@@ -586,7 +584,6 @@ class GenerationMixin:
                     sequence_length=sequence_length,
                     target_length=past_key_values.get_max_cache_shape(),
                     dtype=self.dtype,
-                    device=device,
                     cache_position=cache_position,
                     batch_size=batch_size,
                     config=self.config,
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 1b9892c94f8..8bab20bcee1 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -1003,7 +1003,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
             ):
                 return None

-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1020,7 +1020,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1045,7 +1044,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1065,8 +1063,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1078,11 +1074,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index 0fb696c1f73..074eea3aa97 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -1313,7 +1313,7 @@ class BambaModel(BambaPreTrainedModel):
             ):
                 return None

-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         target_length = (
             attention_mask.shape[-1]
@@ -1327,7 +1327,6 @@ class BambaModel(BambaPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1352,7 +1351,6 @@ class BambaModel(BambaPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1372,8 +1370,6 @@ class BambaModel(BambaPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`): @@ -1385,11 +1381,11 @@ class BambaModel(BambaPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 3e0eb513d5c..dd0d0e62c62 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -1081,7 +1081,7 @@ class BambaModel(BambaPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] target_length = ( attention_mask.shape[-1] @@ -1095,7 +1095,6 @@ class BambaModel(BambaPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1120,7 +1119,6 @@ class BambaModel(BambaPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1140,8 +1138,6 @@ class BambaModel(BambaPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1153,11 +1149,11 @@ class BambaModel(BambaPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f9aee9362db..a7293bf5415 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -773,7 +773,7 @@ class BloomModel(BloomPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -790,7 +790,6 @@ class BloomModel(BloomPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -816,7 +815,6 @@ class BloomModel(BloomPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -836,8 +834,6 @@ class BloomModel(BloomPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -849,11 +845,11 @@ class BloomModel(BloomPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 1c83ddea5a7..9a7d43bdb5e 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1406,7 +1406,7 @@ class ChameleonModel(ChameleonPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1423,7 +1423,6 @@ class ChameleonModel(ChameleonPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1449,7 +1448,6 @@ class ChameleonModel(ChameleonPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1469,8 +1467,6 @@ class ChameleonModel(ChameleonPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1482,11 +1478,11 @@ class ChameleonModel(ChameleonPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 1bc9ceeaa3b..0444c278d75 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -619,7 +619,7 @@ class CodeGenModel(CodeGenPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -636,7 +636,6 @@ class CodeGenModel(CodeGenPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -662,7 +661,6 @@ class CodeGenModel(CodeGenPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -682,8 +680,6 @@ class CodeGenModel(CodeGenPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -695,11 +691,11 @@ class CodeGenModel(CodeGenPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9d13b304862..610144d43c9 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -652,7 +652,7 @@ class CohereModel(CoherePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -669,7 +669,6 @@ class CohereModel(CoherePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -694,7 +693,6 @@ class CohereModel(CoherePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -714,8 +712,6 @@ class CohereModel(CoherePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -727,11 +723,11 @@ class CohereModel(CoherePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index cc790124ccd..ab6055eaa11 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -686,7 +686,6 @@ class Cohere2Model(Cohere2PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -706,8 +705,6 @@ class Cohere2Model(Cohere2PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -719,11 +716,11 @@ class Cohere2Model(Cohere2PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 729e877cc8f..565a1b62390 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1144,7 +1144,7 @@ class DbrxModel(DbrxPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1161,7 +1161,6 @@ class DbrxModel(DbrxPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1187,7 +1186,6 @@ class DbrxModel(DbrxPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1207,8 +1205,6 @@ class DbrxModel(DbrxPreTrainedModel): to account for the 0 padding, the part of the cache that 
is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -1220,11 +1216,11 @@ class DbrxModel(DbrxPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 511c401f7e3..043f1fbcbf0 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -797,7 +797,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -814,7 +814,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -839,7 +838,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -859,8 +857,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -872,11 +868,11 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index b2d736cb4f9..10bed9ccaa5 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -900,7 +900,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -917,7 +917,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -942,7 +941,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -962,8 +960,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -975,11 +971,11 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 375b1beb230..f5a626a21fa 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1486,7 +1486,7 @@ class Emu3TextModel(Emu3PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1503,7 +1503,6 @@ class Emu3TextModel(Emu3PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1528,7 +1527,6 @@ class Emu3TextModel(Emu3PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1548,8 +1546,6 @@ class Emu3TextModel(Emu3PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1561,11 +1557,11 @@ class Emu3TextModel(Emu3PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 86b974570d6..b3f5b12f626 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1076,7 +1076,6 @@ class FalconModel(FalconPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1096,8 +1095,6 @@ class FalconModel(FalconPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -1109,11 +1106,11 @@ class FalconModel(FalconPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 481fa70eb25..0d011a5f916 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -620,7 +620,7 @@ class GemmaModel(GemmaPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -637,7 +637,6 @@ class GemmaModel(GemmaPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -662,7 +661,6 @@ class GemmaModel(GemmaPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -682,8 +680,6 @@ class GemmaModel(GemmaPreTrainedModel): to account for the 0 padding, the part of the cache that 
is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -695,11 +691,11 @@ class GemmaModel(GemmaPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fd63ec26c1c..f779a74fae0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -711,7 +711,6 @@ class Gemma2Model(Gemma2PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -731,8 +730,6 @@ class Gemma2Model(Gemma2PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -744,11 +741,11 @@ class Gemma2Model(Gemma2PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 951e8d78ca9..79da2149d9e 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -796,7 +796,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -816,8 +815,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. 
cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -829,11 +826,11 @@ class Gemma3TextModel(Gemma3PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 9a688e0b8e6..1d72c48232a 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -635,7 +635,7 @@ class GlmModel(GlmPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -652,7 +652,6 @@ class GlmModel(GlmPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -677,7 +676,6 @@ class GlmModel(GlmPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -697,8 +695,6 @@ class GlmModel(GlmPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -710,11 +706,11 @@ class GlmModel(GlmPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index faf188387db..aee3d975727 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -643,7 +643,7 @@ class Glm4Model(Glm4PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -660,7 +660,6 @@ class Glm4Model(Glm4PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -685,7 +684,6 @@ class Glm4Model(Glm4PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -705,8 +703,6 @@ class Glm4Model(Glm4PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -718,11 +714,11 @@ class Glm4Model(Glm4PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6233b03895b..3b38c7fd034 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -820,7 +820,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -837,7 +837,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -863,7 +862,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -883,8 +881,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -896,11 +892,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e72b82eb8b5..ded5ad00d45 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -634,7 +634,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -651,7 +651,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -676,7 +675,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -696,8 +694,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -709,11 +705,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 949e26d2b2f..cd4d904b126 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -671,7 +671,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -688,7 +688,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -714,7 +713,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -734,8 +732,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -747,11 +743,11 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 50344787bc2..a6e06f80ed8 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -921,7 +921,7 @@ class GPTJModel(GPTJPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -938,7 +938,6 @@ class GPTJModel(GPTJPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -964,7 +963,6 @@ class GPTJModel(GPTJPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -984,8 +982,6 @@ class GPTJModel(GPTJPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -997,11 +993,11 @@ class GPTJModel(GPTJPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 4538a323c74..9446013cf44 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -635,7 +635,7 @@ class GraniteModel(GranitePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -652,7 +652,6 @@ class GraniteModel(GranitePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -677,7 +676,6 @@ class GraniteModel(GranitePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -697,8 +695,6 @@ class GraniteModel(GranitePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -710,11 +706,11 @@ class GraniteModel(GranitePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 40bff42cdfc..d417535db78 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1120,7 +1120,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1137,7 +1137,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1163,7 +1162,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1183,8 +1181,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1196,11 +1192,11 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index fe62089a005..59d679fb3c2 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -1065,7 +1065,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1082,7 +1082,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1107,7 +1106,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1127,8 +1125,6 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1140,11 +1136,11 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 9f7b4b87933..90d0c7011bd 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -620,7 +620,7 @@ class HeliumModel(HeliumPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -637,7 +637,6 @@ class HeliumModel(HeliumPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -662,7 +661,6 @@ class HeliumModel(HeliumPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -682,8 +680,6 @@ class HeliumModel(HeliumPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -695,11 +691,11 @@ class HeliumModel(HeliumPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 73d3ac4fab6..9b7d2603004 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1416,7 +1416,7 @@ class IdeficsModel(IdeficsPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1433,7 +1433,6 @@ class IdeficsModel(IdeficsPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1459,7 +1458,6 @@ class IdeficsModel(IdeficsPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1479,8 +1477,6 @@ class IdeficsModel(IdeficsPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1492,11 +1488,11 @@ class IdeficsModel(IdeficsPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bca39ba83e2..cabebb90ef3 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1123,7 +1123,7 @@ class JetMoeModel(JetMoePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1140,7 +1140,6 @@ class JetMoeModel(JetMoePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1166,7 +1165,6 @@ class JetMoeModel(JetMoePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1186,8 +1184,6 @@ class JetMoeModel(JetMoePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1199,11 +1195,11 @@ class JetMoeModel(JetMoePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f6d5471ce3d..86a4613b155 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -625,7 +625,7 @@ class LlamaModel(LlamaPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -642,7 +642,6 @@ class LlamaModel(LlamaPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -667,7 +666,6 @@ class LlamaModel(LlamaPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -687,8 +685,6 @@ class LlamaModel(LlamaPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -700,11 +696,11 @@ class LlamaModel(LlamaPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 3718a43e24a..d1611a2fe09 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -785,7 +785,6 @@ class Llama4TextModel(Llama4PreTrainedModel): sequence_length=sequence_length, target_length=max(full_cache_length, attention_chunk_size), dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1846,7 +1845,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1866,8 +1864,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
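For reference, a minimal standalone sketch of the mask-construction pattern that the Llama-style hunks above converge on after this change: the mask is allocated on cache_position.device instead of a separately threaded `device` argument. The function name build_causal_mask and the example values at the end are illustrative only, not part of the patch.

import torch

def build_causal_mask(sequence_length, target_length, dtype, cache_position, batch_size):
    # Start fully masked (min value of the dtype), then open up attendable positions.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Positions beyond each query's cache position stay masked; earlier positions become 0 (attendable).
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    return causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Example: a 4-token prefill against a static cache of length 8.
mask = build_causal_mask(4, 8, torch.float32, torch.arange(4), batch_size=1)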
batch_size (`torch.Tensor`): @@ -1879,11 +1875,11 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index f44b28f9017..b4434eae50a 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1631,7 +1631,7 @@ class LongT5Stack(LongT5PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1648,7 +1648,6 @@ class LongT5Stack(LongT5PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1674,7 +1673,6 @@ class LongT5Stack(LongT5PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1694,8 +1692,6 @@ class LongT5Stack(LongT5PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1707,11 +1703,11 @@ class LongT5Stack(LongT5PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index a1006ed011b..b8f24040ddb 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1086,7 +1086,7 @@ class MimiTransformerModel(nn.Module): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -1106,7 +1106,6 @@ class MimiTransformerModel(nn.Module): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -1133,7 +1132,6 @@ class MimiTransformerModel(nn.Module): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MimiConfig, @@ -1152,8 +1150,6 @@ class MimiTransformerModel(nn.Module): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1169,14 +1165,16 @@ class MimiTransformerModel(nn.Module): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ebd0fe6c018..35759852323 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -609,7 +609,7 @@ class MistralModel(MistralPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -629,7 +629,6 @@ class MistralModel(MistralPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -655,7 +654,6 @@ class MistralModel(MistralPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MistralConfig, @@ -674,8 +672,6 @@ class MistralModel(MistralPreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -691,14 +687,16 @@ class MistralModel(MistralPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 84062fde3e1..0d525407152 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -171,7 +171,7 @@ class MistralModel(LlamaModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -191,7 +191,6 @@ class MistralModel(LlamaModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -217,7 +216,6 @@ class MistralModel(LlamaModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MistralConfig, @@ -236,8 +234,6 @@ class MistralModel(LlamaModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
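The sliding-window flavour used by the Mistral-family hunks (Mimi, Mistral, Mixtral, Moshi, Phi-3, ...) follows the same pattern: allocate on cache_position.device, then additionally block tokens older than the window. A self-contained sketch under made-up names; `sliding_window` stands in for config.get_text_config().sliding_window, and the SlidingWindowCache check of the real helper is left out.

import torch

def build_sliding_window_mask(sequence_length, target_length, dtype, cache_position, batch_size, sliding_window):
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    # True where a position must NOT be attended to: the future ...
    diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    # ... and anything further back than the sliding window.
    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
        cache_position.reshape(-1, 1) - sliding_window
    )
    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
    causal_mask *= diagonal_attend_mask
    return causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Example: 8 tokens, cache length 8, window of 4.
mask = build_sliding_window_mask(8, 8, torch.float32, torch.arange(8), 1, sliding_window=4)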
batch_size (`torch.Tensor`): @@ -253,14 +249,16 @@ class MistralModel(LlamaModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 377e18b875e..e8012a8d9ef 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -751,7 +751,7 @@ class MixtralModel(MixtralPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -771,7 +771,6 @@ class MixtralModel(MixtralPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -797,7 +796,6 @@ class MixtralModel(MixtralPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MixtralConfig, @@ -816,8 +814,6 @@ class MixtralModel(MixtralPreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -833,14 +829,16 @@ class MixtralModel(MixtralPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index db961a30dc8..85a2ccd5ecf 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1092,7 +1092,7 @@ class MllamaPreTrainedModel(PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1109,7 +1109,6 @@ class MllamaPreTrainedModel(PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1135,7 +1134,6 @@ class MllamaPreTrainedModel(PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1155,8 +1153,6 @@ class MllamaPreTrainedModel(PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1168,11 +1164,11 @@ class MllamaPreTrainedModel(PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 6d10364a6e9..3d9b96ecdb4 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -968,7 +968,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -985,7 +985,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1010,7 +1009,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1030,8 +1028,6 @@ class MoonshineDecoder(MoonshinePreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1043,11 +1039,11 @@ class MoonshineDecoder(MoonshinePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 6e990c44028..6fa1c0c5e4e 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1312,7 +1312,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -1332,7 +1332,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -1359,7 +1358,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MoshiDepthConfig, @@ -1378,8 +1376,6 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1395,14 +1391,16 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) @@ -1630,7 +1628,7 @@ class MoshiModel(MoshiPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -1650,7 +1648,6 @@ class MoshiModel(MoshiPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -1677,7 +1674,6 @@ class MoshiModel(MoshiPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: MoshiConfig, @@ -1696,8 +1692,6 @@ class MoshiModel(MoshiPreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1713,14 +1707,16 @@ class MoshiModel(MoshiPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 27badbbeeee..d2d615dbfd8 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1222,7 +1222,7 @@ class MT5Stack(MT5PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1239,7 +1239,6 @@ class MT5Stack(MT5PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1265,7 +1264,6 @@ class MT5Stack(MT5PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1285,8 +1283,6 @@ class MT5Stack(MT5PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1298,11 +1294,11 @@ class MT5Stack(MT5PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index d33cb3a24b0..6d1b9609e8f 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -882,7 +882,7 @@ class NemotronModel(NemotronPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -899,7 +899,6 @@ class NemotronModel(NemotronPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -925,7 +924,6 @@ class NemotronModel(NemotronPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -945,8 +943,6 @@ class NemotronModel(NemotronPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -958,11 +954,11 @@ class NemotronModel(NemotronPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6de200e6628..5ffcd27a236 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -595,7 +595,7 @@ class OlmoModel(OlmoPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -612,7 +612,6 @@ class OlmoModel(OlmoPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -637,7 +636,6 @@ class OlmoModel(OlmoPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -657,8 +655,6 @@ class OlmoModel(OlmoPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -670,11 +666,11 @@ class OlmoModel(OlmoPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 31e6805cfde..e2c246fa1ad 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -601,7 +601,7 @@ class Olmo2Model(Olmo2PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -618,7 +618,6 @@ class Olmo2Model(Olmo2PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -643,7 +642,6 @@ class Olmo2Model(Olmo2PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -663,8 +661,6 @@ class Olmo2Model(Olmo2PreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -676,11 +672,11 @@ class Olmo2Model(Olmo2PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 7e097cca311..61bec50b67a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -672,7 +672,7 @@ class OPTDecoder(OPTPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -689,7 +689,6 @@ class OPTDecoder(OPTPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -715,7 +714,6 @@ class OPTDecoder(OPTPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -735,8 +733,6 @@ class OPTDecoder(OPTPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -748,11 +744,11 @@ class OPTDecoder(OPTPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index ade69444f46..39b63c4c406 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -49,7 +49,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, min_dtype: float, cache_position: torch.Tensor, batch_size: int, @@ -70,8 +69,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. min_dtype (`float`): The minimum value representable with the dtype `dtype`. cache_position (`torch.Tensor`): @@ -85,7 +82,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. 
Training-specific attn for prefix is handled below if sequence_length != 1: if is_training: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 564865ddef6..8789c205bed 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -682,7 +682,7 @@ class PersimmonModel(PersimmonPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -699,7 +699,6 @@ class PersimmonModel(PersimmonPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -725,7 +724,6 @@ class PersimmonModel(PersimmonPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -745,8 +743,6 @@ class PersimmonModel(PersimmonPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -758,11 +754,11 @@ class PersimmonModel(PersimmonPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8b8bd4b8e82..aa746de1ff2 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -606,7 +606,7 @@ class PhiModel(PhiPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -623,7 +623,6 @@ class PhiModel(PhiPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -648,7 +647,6 @@ class PhiModel(PhiPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -668,8 +666,6 @@ class PhiModel(PhiPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. 
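PaliGemma's helper (shown just above) differs slightly: min_dtype is passed in explicitly and the causal triu is only applied during training. A trimmed sketch of just the part visible in that hunk, with a made-up function name; the non-training prefix handling and the later padding handling of the real function are omitted.

import torch

def build_prefix_lm_mask(sequence_length, target_length, dtype, min_dtype, cache_position, is_training):
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    # Causal (upper-triangular) masking is only applied while training; the
    # non-training branch of the real helper instead leaves the whole prefix
    # attendable (that branch is omitted from this sketch).
    if sequence_length != 1 and is_training:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    return causal_mask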
cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): @@ -681,11 +677,11 @@ class PhiModel(PhiPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index aaf5332c717..03422502e11 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -664,7 +664,7 @@ class Phi3Model(Phi3PreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -684,7 +684,6 @@ class Phi3Model(Phi3PreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -710,7 +709,6 @@ class Phi3Model(Phi3PreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: Phi3Config, @@ -729,8 +727,6 @@ class Phi3Model(Phi3PreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -746,14 +742,16 @@ class Phi3Model(Phi3PreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 3b9979edf31..72558e3f7c8 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1949,7 +1949,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -1969,7 +1969,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -1995,7 +1994,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: Phi4MultimodalConfig, @@ -2014,8 +2012,6 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -2031,14 +2027,16 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 683aa431cef..388505cf33e 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1221,7 +1221,7 @@ class PhimoeModel(PhimoePreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache @@ -1241,7 +1241,6 @@ class PhimoeModel(PhimoePreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, @@ -1268,7 +1267,6 @@ class PhimoeModel(PhimoePreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, config: PhimoeConfig, @@ -1287,8 +1285,6 @@ class PhimoeModel(PhimoePreTrainedModel): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1304,14 +1300,16 @@ class PhimoeModel(PhimoePreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 192f4a10b12..63250a47f60 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1618,7 +1618,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1635,7 +1635,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1661,7 +1660,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1681,8 +1679,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
batch_size (`torch.Tensor`): @@ -1694,11 +1690,11 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 0a8b4795543..6080a710c05 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1031,7 +1031,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): ): return None - dtype, device = input_tensor.dtype, input_tensor.device + dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1048,7 +1048,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): sequence_length=sequence_length, target_length=target_length, dtype=dtype, - device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) @@ -1074,7 +1073,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): sequence_length: int, target_length: int, dtype: torch.dtype, - device: torch.device, cache_position: torch.Tensor, batch_size: int, **kwargs, @@ -1094,8 +1092,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. 
             batch_size (`torch.Tensor`):
@@ -1107,11 +1103,11 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 46fb6326ada..ea70a06ee3c 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -622,7 +622,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -642,7 +642,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -668,7 +667,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2Config,
@@ -687,8 +685,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -704,14 +700,16 @@ class Qwen2Model(Qwen2PreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 18b23b3949f..e79f641de3b 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2088,7 +2088,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -2108,7 +2108,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -2134,7 +2133,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2_5OmniConfig,
@@ -2153,8 +2151,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -2170,14 +2166,16 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
@@ -2806,7 +2804,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -2826,7 +2824,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -2852,7 +2849,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2_5OmniConfig,
@@ -2871,8 +2867,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -2888,14 +2882,16 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 4da0f59bf46..6c19b9fdcfb 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1266,7 +1266,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -1286,7 +1286,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -1312,7 +1311,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2_5_VLConfig,
@@ -1331,8 +1329,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1348,14 +1344,16 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index d780d6051ab..00cb7896fd1 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1082,7 +1082,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -1102,7 +1102,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -1129,7 +1128,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2MoeConfig,
@@ -1148,8 +1146,6 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1165,14 +1161,16 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index ebc3740c357..ce008fdaf9f 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1218,7 +1218,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -1238,7 +1238,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -1265,7 +1264,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen2VLConfig,
@@ -1284,8 +1282,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1301,14 +1297,16 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
@@ -1692,9 +1690,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
                 video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
         # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
         if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
             # calculate RoPE index once per generation in the pre-fill stage only
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 2a869fe6304..d44eb4fb9ae 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -649,7 +649,7 @@ class Qwen3Model(Qwen3PreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -669,7 +669,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -695,7 +694,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen3Config,
@@ -714,8 +712,6 @@ class Qwen3Model(Qwen3PreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -731,14 +727,16 @@ class Qwen3Model(Qwen3PreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index daf8335acd8..9a476e07689 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -765,7 +765,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -785,7 +785,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -811,7 +810,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Qwen3MoeConfig,
@@ -830,8 +828,6 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -847,14 +843,16 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 85e20b44932..c09e4962a0d 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -936,7 +936,7 @@ class StableLmModel(StableLmPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -953,7 +953,6 @@ class StableLmModel(StableLmPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -979,7 +978,6 @@ class StableLmModel(StableLmPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -999,8 +997,6 @@ class StableLmModel(StableLmPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1012,11 +1008,11 @@ class StableLmModel(StableLmPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 84e420607b9..b56a4fcf68b 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -613,7 +613,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
@@ -633,7 +633,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
             config=self.config,
@@ -659,7 +658,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         config: Starcoder2Config,
@@ -678,8 +676,6 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
                 The target length: when generating with static cache, the mask should be as long as the static cache,
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`): The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -695,14 +691,16 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+                -1, 1
             )
-            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             if config.get_text_config().sliding_window is not None:
                 # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                 # the check is needed to verify is current checkpoint was trained with sliding window or not
                 if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
+                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                         cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                     )
                     diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 3dd8a139875..396c448cc74 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1165,7 +1165,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1182,7 +1182,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1208,7 +1207,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1228,8 +1226,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1241,11 +1237,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index e8085a8f456..8f6b5de8081 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1236,7 +1236,7 @@ class T5Stack(T5PreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1253,7 +1253,6 @@ class T5Stack(T5PreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1279,7 +1278,6 @@ class T5Stack(T5PreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1299,8 +1297,6 @@ class T5Stack(T5PreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1312,11 +1308,11 @@ class T5Stack(T5PreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 9cd7f984237..665eedc8106 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -1568,7 +1568,7 @@ class UdopStack(UdopPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1585,7 +1585,6 @@ class UdopStack(UdopPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1611,7 +1610,6 @@ class UdopStack(UdopPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1631,8 +1629,6 @@ class UdopStack(UdopPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1644,11 +1640,11 @@ class UdopStack(UdopPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 44586243c3f..18fb5cc5f04 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -879,7 +879,7 @@ class UMT5Stack(UMT5PreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -896,7 +896,6 @@ class UMT5Stack(UMT5PreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -922,7 +921,6 @@ class UMT5Stack(UMT5PreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -942,8 +940,6 @@ class UMT5Stack(UMT5PreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -955,11 +951,11 @@ class UMT5Stack(UMT5PreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b7a454839c0..8c5432c5c65 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1406,7 +1406,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
         ):
             return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
@@ -1423,7 +1423,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )
@@ -1449,7 +1448,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,
@@ -1469,8 +1467,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1482,11 +1478,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index b9916e8dfa8..03601a6d72a 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -4018,6 +4018,34 @@ class GenerationIntegrationTests(unittest.TestCase):
         value_cache_1 = results.past_key_values.value_cache[1]
         self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
 
+    @pytest.mark.generate
+    @require_torch_multi_gpu
+    def test_generate_multi_gpu_causal_mask(self):
+        """
+        Tests that the cache position device doesn't clash with the causal mask device when using multiple GPUs.
+        In real life this happens only when the multimodal encoder is big, so `embed_tokens` gets allocated to the next device.
+        The error will be triggered whenever a batched input is used, so that `causal_mask` is actually prepared instead of
+        being `None`.
+        """
+        # need to split manually as auto doesn't work well with unbalanced model
+        device_map = {
+            "visual": 0,
+            "model.embed_tokens": 1,
+            "model.layers.0": 1,
+            "model.layers.1": 1,
+            "model.rotary_emb": 1,
+            "model.norm.weight": 1,
+            "lm_head": 1,
+        }
+        model = AutoModelForImageTextToText.from_pretrained(
+            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", device_map=device_map
+        )
+        processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+        text = ["Hello world", "Today I went to the supermarket to buy"]
+        inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
+        _ = model.generate(**inputs, max_new_tokens=20)
+
     @pytest.mark.generate
     @require_torch_multi_gpu
     def test_init_static_cache_multi_gpu(self):