Align XPU's autocast behavior with CUDA by using device-agnostic torch APIs (#38284)

* switch to device-agnostic autocast in nemotron to align xpu behavior with cuda

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix issue

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* use torch.autocast as in other modeling code for decision_transformer, gpt2 and imagegpt

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* refine

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update get_autocast_gpu_dtype to the device-agnostic torch.get_autocast_dtype

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Authored by Yao Matrix on 2025-06-19 19:48:23 +08:00, committed by GitHub
parent 0a53df1a77
commit a9ce8c69c9
26 changed files with 138 additions and 37 deletions
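The recurring change across the files below replaces the CUDA-only dtype query with a device-agnostic one. A minimal sketch of that pattern, assuming PyTorch >= 2.4 provides torch.get_autocast_dtype (older releases fall back to the CUDA-only torch.get_autocast_gpu_dtype); the helper name autocast_target_dtype is illustrative and not part of the patch:

    import torch

    def autocast_target_dtype(device_type: str) -> torch.dtype:
        # Device-agnostic query, valid for "cuda", "xpu", "cpu", ... (PyTorch >= 2.4)
        if hasattr(torch, "get_autocast_dtype"):
            return torch.get_autocast_dtype(device_type)
        # Legacy fallback that only reports the CUDA autocast dtype
        return torch.get_autocast_gpu_dtype()

In the diffs that follow, this lookup is only performed under torch.is_autocast_enabled(), and "mps" is mapped to "cpu" before the call.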


@@ -100,7 +100,7 @@ class ChameleonRotaryEmbedding(nn.Module):
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ device_type = device_type if device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)


@@ -64,7 +64,7 @@ class DbrxRotaryEmbedding(nn.Module):
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ device_type = device_type if device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
@@ -387,9 +387,14 @@ class DbrxFlashAttention2(DbrxAttention):
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -219,7 +219,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
scale_factor /= float(self.layer_idx + 1)
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
- with torch.amp.autocast(query.device.type, enabled=False):
+ with torch.autocast(query.device.type, enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)


@@ -306,9 +306,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
# in fp32. (DiffLlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -239,9 +239,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
# in fp32. (DiffLlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -289,9 +289,14 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if query_states.dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -459,9 +459,14 @@ class EsmFlashAttention2(EsmSelfAttention):
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_layer.dtype
+ device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput):
max_predicted_aligned_error: Optional[torch.FloatTensor] = None
- def is_fp16_enabled():
+ def is_fp16_enabled(device_type):
# Autocast world
- fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+ autocast_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ fp16_enabled = autocast_dtype == torch.float16
fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
return fp16_enabled
@@ -885,8 +890,9 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
- if is_fp16_enabled():
- with torch.cuda.amp.autocast(enabled=False):
+ device_type = a.device.type if a.device.type != "mps" else "cpu"
+ if is_fp16_enabled(device_type):
+ with torch.autocast(device_type=device_type, enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
@@ -1499,8 +1505,9 @@ class EsmFoldInvariantPointAttention(nn.Module):
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
- if is_fp16_enabled():
- with torch.cuda.amp.autocast(enabled=False):
+ device_type = q.device.type if q.device.type != "mps" else "cpu"
+ if is_fp16_enabled(device_type):
+ with torch.autocast(device_type=device_type, enabled=False):
a = torch.matmul(
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]


@@ -488,9 +488,14 @@ class FalconFlashAttention2(FalconAttention):
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype
+ device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -229,7 +229,7 @@ class GPT2Attention(nn.Module):
scale_factor /= float(self.layer_idx + 1)
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
- with torch.amp.autocast(query.device.type, enabled=False):
+ with torch.autocast(query.device.type, enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)


@@ -343,9 +343,14 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query.dtype
+ device_type = query.device.type if query.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -323,9 +323,14 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
+ device_type = query.device.type if query.device.type != "mps" else "cpu"
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -355,9 +355,14 @@ class GPTJFlashAttention2(GPTJAttention):
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query.dtype
+ device_type = query.device.type if query.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -22,7 +22,6 @@ from typing import Any, Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
- from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
@@ -280,7 +279,7 @@ class ImageGPTAttention(nn.Module):
scale_factor /= float(self.layer_idx + 1)
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
- with autocast(enabled=False):
+ with torch.autocast(query.device.type, enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)


@@ -412,9 +412,14 @@ class JambaFlashAttention2(JambaAttention):
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -710,9 +710,14 @@ class JetMoeFlashAttention2(JetMoeAttention):
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -661,9 +661,14 @@ class MimiFlashAttention2(MimiAttention):
# in fp32. (MimiRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -594,9 +594,14 @@ class MoshiFlashAttention2(MoshiAttention):
# in fp32. (MoshiRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -51,11 +51,17 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__)
- def _cast_if_autocast_enabled(*args):
+ def _cast_if_autocast_enabled(device_type, *args):
if not torch.is_autocast_enabled():
return args
else:
- return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
+ # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ return torch.amp.autocast_mode._cast(args, device_type, target_dtype)
class NemotronLayerNorm1P(nn.LayerNorm):
@@ -71,8 +77,11 @@ class NemotronLayerNorm1P(nn.LayerNorm):
super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
def forward(self, input: Tensor) -> Tensor:
- args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
- with torch.amp.autocast(input.device.type, enabled=False):
+ device_type = input.device.type if input.device.type != "mps" else "cpu"
+ args = _cast_if_autocast_enabled(
+ device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
+ )
+ with torch.autocast(device_type=input.device.type, enabled=False):
return F.layer_norm(*args)
@@ -344,9 +353,15 @@ class NemotronFlashAttention2(NemotronAttention):
# in fp32. (NemotronRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -421,9 +421,14 @@ class OlmoeFlashAttention2(OlmoeAttention):
# in fp32. (OlmoeRMSNorm handles it correctly)
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -369,9 +369,14 @@ class PhimoeFlashAttention2(PhimoeAttention):
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -2638,7 +2638,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
batch_size, seq_len = x.shape[0], x.shape[1]
t = torch.arange(seq_len, device=x.device)
device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ device_type = device_type if device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
freqs = torch.stack((freqs, freqs), dim=-1)


@@ -2906,7 +2906,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
batch_size, seq_len = x.shape[0], x.shape[1]
t = torch.arange(seq_len, device=x.device)
device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ device_type = device_type if device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
freqs = torch.stack((freqs, freqs), dim=-1)


@@ -418,9 +418,14 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype


@@ -78,7 +78,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ device_type = device_type if device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)


@@ -462,7 +462,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
torch.manual_seed(42)
- with torch.amp.autocast(device_type=torch_device, dtype=dtype):
+ with torch.autocast(device_type=torch_device, dtype=dtype):
with torch.no_grad():
mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
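For reference, torch.autocast is the device-agnostic context manager the test above switches to. A hedged usage sketch, assuming an XPU or CPU backend (the device selection below is illustrative and not part of the test):

    import torch

    # Pick whichever backend is present; fall back to CPU.
    torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

    with torch.autocast(device_type=torch_device, dtype=torch.bfloat16):
        a = torch.randn(4, 4, device=torch_device)
        b = torch.randn(4, 4, device=torch_device)
        out = a @ b  # matmul runs under autocast the same way on CUDA, XPU, or CPU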