Align XPU's autocast behavior with CUDA by using device-agnostic torch APIs (#38284)

* switch to device-agnostic autocast in nemotron to align XPU behavior with CUDA

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix issue

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* use torch.autocast, as other modeling code does, for decision_transformer & gpt2 & imagegpt

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* refine

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update get_autocast_gpu_dtype to the device-agnostic torch.get_autocast_dtype

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Yao Matrix authored on 2025-06-19 19:48:23 +08:00; committed by GitHub
parent 0a53df1a77
commit a9ce8c69c9
26 changed files with 138 additions and 37 deletions
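
Note on the pattern applied across these files: the CUDA-only calls (`torch.get_autocast_gpu_dtype()`, `torch.cuda.amp.autocast`) are replaced with the device-agnostic `torch.autocast(device_type, ...)` and `torch.get_autocast_dtype(device_type)`, keeping a fallback for PyTorch builds older than 2.4 that do not ship `torch.get_autocast_dtype`. A minimal sketch of the dtype-selection logic (the helper name and the `hidden_states` argument are illustrative, not part of the diff):

```python
import torch


def _flash_attn_target_dtype(hidden_states: torch.Tensor) -> torch.dtype:
    # Mirror the modeling code: query autocast per device type, mapping "mps" to "cpu".
    device_type = hidden_states.device.type if hidden_states.device.type != "mps" else "cpu"
    if torch.is_autocast_enabled():
        # `torch.get_autocast_dtype` is available starting from PyTorch 2.4;
        # older builds only expose the CUDA-specific `torch.get_autocast_gpu_dtype`.
        return (
            torch.get_autocast_dtype(device_type)
            if hasattr(torch, "get_autocast_dtype")
            else torch.get_autocast_gpu_dtype()
        )
    return hidden_states.dtype
```

With the device type in hand, the upcast blocks simply disable autocast locally via `with torch.autocast(device_type=device_type, enabled=False):`, which behaves the same on CUDA and XPU.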

@@ -100,7 +100,7 @@ class ChameleonRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)

@@ -64,7 +64,7 @@ class DbrxRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -387,9 +387,14 @@ class DbrxFlashAttention2(DbrxAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -219,7 +219,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

@@ -306,9 +306,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -239,9 +239,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -289,9 +289,14 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if query_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -459,9 +459,14 @@ class EsmFlashAttention2(EsmSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput):
     max_predicted_aligned_error: Optional[torch.FloatTensor] = None


-def is_fp16_enabled():
+def is_fp16_enabled(device_type):
     # Autocast world
-    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+    autocast_dtype = (
+        torch.get_autocast_dtype(device_type)
+        if hasattr(torch, "get_autocast_dtype")
+        else torch.get_autocast_gpu_dtype()
+    )
+    fp16_enabled = autocast_dtype == torch.float16
     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

     return fp16_enabled
@@ -885,8 +890,9 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)

-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = a.device.type if a.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
@@ -1499,8 +1505,9 @@ class EsmFoldInvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = q.device.type if q.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]

@@ -488,9 +488,14 @@ class FalconFlashAttention2(FalconAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -229,7 +229,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

@@ -343,9 +343,14 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -323,9 +323,14 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if query.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -355,9 +355,14 @@ class GPTJFlashAttention2(GPTJAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -22,7 +22,6 @@ from typing import Any, Optional, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -280,7 +279,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

@@ -412,9 +412,14 @@ class JambaFlashAttention2(JambaAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -710,9 +710,14 @@ class JetMoeFlashAttention2(JetMoeAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -661,9 +661,14 @@ class MimiFlashAttention2(MimiAttention):
         # in fp32. (MimiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -594,9 +594,14 @@ class MoshiFlashAttention2(MoshiAttention):
         # in fp32. (MoshiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -51,11 +51,17 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)


-def _cast_if_autocast_enabled(*args):
+def _cast_if_autocast_enabled(device_type, *args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
+        # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+        target_dtype = (
+            torch.get_autocast_dtype(device_type)
+            if hasattr(torch, "get_autocast_dtype")
+            else torch.get_autocast_gpu_dtype()
+        )
+        return torch.amp.autocast_mode._cast(args, device_type, target_dtype)


 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -71,8 +77,11 @@ class NemotronLayerNorm1P(nn.LayerNorm):
         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)

     def forward(self, input: Tensor) -> Tensor:
-        args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.amp.autocast(input.device.type, enabled=False):
+        device_type = input.device.type if input.device.type != "mps" else "cpu"
+        args = _cast_if_autocast_enabled(
+            device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
+        )
+        with torch.autocast(device_type=input.device.type, enabled=False):
             return F.layer_norm(*args)


@@ -344,9 +344,15 @@ class NemotronFlashAttention2(NemotronAttention):
         # in fp32. (NemotronRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -421,9 +421,14 @@ class OlmoeFlashAttention2(OlmoeAttention):
         # in fp32. (OlmoeRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -369,9 +369,14 @@ class PhimoeFlashAttention2(PhimoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -2638,7 +2638,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)

@@ -2906,7 +2906,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)

@@ -418,9 +418,14 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype

@@ -78,7 +78,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)

@@ -462,7 +462,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)

         torch.manual_seed(42)
-        with torch.amp.autocast(device_type=torch_device, dtype=dtype):
+        with torch.autocast(device_type=torch_device, dtype=dtype):
             with torch.no_grad():
                 mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
                 hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
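
For completeness, a hypothetical smoke test (not part of this commit) that exercises the device-agnostic path; it assumes PyTorch >= 2.4 for `torch.get_autocast_dtype` and a machine with an XPU or CUDA device:

```python
import torch

# Pick whichever accelerator is present; the autocast query below is identical for both.
device_type = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    # The device-agnostic API reports the dtype configured for this autocast region.
    assert torch.get_autocast_dtype(device_type) == torch.bfloat16
```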
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device) hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)