diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index d6575a8751d..64148e2457f 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -100,7 +100,7 @@ class ChameleonRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 097370a46cd..f1fee43a239 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -64,7 +64,7 @@ class DbrxRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -387,9 +387,14 @@ class DbrxFlashAttention2(DbrxAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index d34e989986d..a436ce6d2c0 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -219,7 +219,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)
 
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 43e8e4584bd..fae9f2dbb95 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -306,9 +306,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py
index 934baf9c0b1..0ff0465c793 100644
--- a/src/transformers/models/diffllama/modular_diffllama.py
+++ b/src/transformers/models/diffllama/modular_diffllama.py
@@ -239,9 +239,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 1a84544bee3..28cec74fb3d 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -289,9 +289,14 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
 
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if query_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index dbf260fb216..10db78a67cb 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -459,9 +459,14 @@ class EsmFlashAttention2(EsmSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py
index 3338b3d5007..644fec2e1a3 100644
--- a/src/transformers/models/esm/modeling_esmfold.py
+++ b/src/transformers/models/esm/modeling_esmfold.py
@@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput):
     max_predicted_aligned_error: Optional[torch.FloatTensor] = None
 
 
-def is_fp16_enabled():
+def is_fp16_enabled(device_type):
     # Autocast world
-    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+    autocast_dtype = (
+        torch.get_autocast_dtype(device_type)
+        if hasattr(torch, "get_autocast_dtype")
+        else torch.get_autocast_gpu_dtype()
+    )
+    fp16_enabled = autocast_dtype == torch.float16
     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
 
     return fp16_enabled
@@ -885,8 +890,9 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
 
-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = a.device.type if a.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
@@ -1499,8 +1505,9 @@ class EsmFoldInvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
 
         # [*, H, N_res, N_res]
-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = q.device.type if q.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 455afbb2157..d6634662f30 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -488,9 +488,14 @@ class FalconFlashAttention2(FalconAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 584d21c4108..fa98bc3614e 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -229,7 +229,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)
 
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 297c30cb06a..b90fdfe8acc 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -343,9 +343,14 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 417eb11ab0f..8ac65c7d1ae 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -323,9 +323,14 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
 
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if query.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 1e9c0ef5332..4388fad01f6 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -355,9 +355,14 @@ class GPTJFlashAttention2(GPTJAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)
 
         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index db5ae763aad..65d5cbc3df2 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -22,7 +22,6 @@ from typing import Any, Optional, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
@@ -280,7 +279,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)
 
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 52570d8f7f8..0b93d4484c9 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -412,9 +412,14 @@ class JambaFlashAttention2(JambaAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index ccee5f34a50..4b6fefbf9e9 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -710,9 +710,14 @@ class JetMoeFlashAttention2(JetMoeAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 9023c93433a..45c307b6136 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -661,9 +661,14 @@ class MimiFlashAttention2(MimiAttention):
         # in fp32. (MimiRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index c0ef9c00147..6eca81ef497 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -594,9 +594,14 @@ class MoshiFlashAttention2(MoshiAttention):
         # in fp32. (MoshiRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 6b5cf370c9d..d9d0248a33d 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -51,11 +51,17 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)
 
 
-def _cast_if_autocast_enabled(*args):
+def _cast_if_autocast_enabled(device_type, *args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
+        # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+        target_dtype = (
+            torch.get_autocast_dtype(device_type)
+            if hasattr(torch, "get_autocast_dtype")
+            else torch.get_autocast_gpu_dtype()
+        )
+        return torch.amp.autocast_mode._cast(args, device_type, target_dtype)
 
 
 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -71,8 +77,11 @@ class NemotronLayerNorm1P(nn.LayerNorm):
         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
 
     def forward(self, input: Tensor) -> Tensor:
-        args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.amp.autocast(input.device.type, enabled=False):
+        device_type = input.device.type if input.device.type != "mps" else "cpu"
+        args = _cast_if_autocast_enabled(
+            device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
+        )
+        with torch.autocast(device_type=input.device.type, enabled=False):
             return F.layer_norm(*args)
 
 
@@ -344,9 +353,15 @@ class NemotronFlashAttention2(NemotronAttention):
         # in fp32. (NemotronRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index 6dc3c12c1ff..a9f2a08124e 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -421,9 +421,14 @@ class OlmoeFlashAttention2(OlmoeAttention):
         # in fp32. (OlmoeRMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 2a924078339..b1651f467b4 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -369,9 +369,14 @@ class PhimoeFlashAttention2(PhimoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index c4e151e9ce1..1ccc6ea0bfb 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -2638,7 +2638,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 10edb4e6a43..3779a2ad061 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2906,7 +2906,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 9882ca447d7..cc617533582 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -418,9 +418,14 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 47e79f34870..6723c520f28 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -78,7 +78,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 5777053923a..dfa8bca69ef 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -462,7 +462,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
 
         torch.manual_seed(42)
-        with torch.amp.autocast(device_type=torch_device, dtype=dtype):
+        with torch.autocast(device_type=torch_device, dtype=dtype):
             with torch.no_grad():
                 mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
                 hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
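All of the hunks above repeat one compatibility idiom: query the autocast dtype with `torch.get_autocast_dtype(device_type)` where it exists (PyTorch 2.4+), fall back to the CUDA-only `torch.get_autocast_gpu_dtype()` on older releases, route `"mps"` through `"cpu"` in line with the existing `device_type != "mps"` guards in these files, and use the device-agnostic `torch.autocast` context instead of `torch.cuda.amp.autocast` / `torch.amp.autocast`. The sketch below shows the idiom in isolation; the helper name `autocast_target_dtype` and the `__main__` demo are illustrative only and are not part of this patch.

```python
import torch


def autocast_target_dtype(device_type: str) -> torch.dtype:
    """Device-aware stand-in for torch.get_autocast_gpu_dtype() (illustrative helper, not in the patch)."""
    # Match the guard used throughout the diff: treat "mps" as "cpu" for the dtype lookup.
    device_type = device_type if device_type != "mps" else "cpu"
    # torch.get_autocast_dtype(device_type) exists from PyTorch 2.4 on; older releases only
    # expose the CUDA-specific torch.get_autocast_gpu_dtype().
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype(device_type)
    return torch.get_autocast_gpu_dtype()


if __name__ == "__main__":
    x = torch.randn(8, 8)
    # Device-agnostic autocast context, used above in place of torch.cuda.amp.autocast.
    with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
        y = x @ x
        # On PyTorch >= 2.4 this reports the dtype configured for the active device;
        # on older releases the fallback can only report the CUDA autocast dtype.
        print(autocast_target_dtype(x.device.type), y.dtype)
```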