Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
align xpu's autocast behavior w/ cuda by using device agnostic torch APIs (#38284)
* switch to device agnostic autocast in nemotron to align xpu behavior w/ cuda
* fix issue
* fix style
* use torch.autocast as other modeling code for decision_transformer & gpt2 & imagegpt
* refine
* update get_autocast_gpu_dtype to device agnostic one
* fix style
* fix comments
* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Commit a9ce8c69c9 (parent 0a53df1a77)
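The modeling hunks below all apply one pattern: replace the CUDA-only autocast getters (`torch.get_autocast_gpu_dtype`, `torch.cuda.amp.autocast`) with a lookup keyed on the tensor's own device type, so XPU (and CPU) take the same path as CUDA. A minimal sketch of that lookup follows; the helper name `resolve_autocast_dtype` is invented for illustration, and the `hasattr` guard mirrors the fallback the diff uses for PyTorch older than 2.4, where `torch.get_autocast_dtype` does not exist:

import torch


def resolve_autocast_dtype(device_type: str) -> torch.dtype:
    # Hypothetical helper mirroring the hunks below: torch.get_autocast_dtype
    # is device agnostic but only exists from PyTorch 2.4 on, so older
    # versions fall back to the CUDA-only getter.
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype(device_type)
    return torch.get_autocast_gpu_dtype()


x = torch.randn(8, 8)
# Normalize the device type the same way the modeling code does ("mps" -> "cpu").
device_type = x.device.type if x.device.type != "mps" else "cpu"
print(resolve_autocast_dtype(device_type))  # autocast dtype configured for this device

On CUDA this resolves to the same dtype `torch.get_autocast_gpu_dtype()` returned before, so the change should be behavior-preserving there, while XPU now picks up its own autocast dtype.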
@@ -100,7 +100,7 @@ class ChameleonRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
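What this rotary-embedding hunk does at runtime, as a standalone sketch with made-up shapes (`inv_freq`, `position_ids`, and `x` below are stand-ins for the module's real buffers and inputs): the frequency matmul is kept in float32 by disabling autocast for whatever device the input actually lives on, rather than assuming CUDA.

import torch

# Stand-ins for ChameleonRotaryEmbedding state; shapes are illustrative only.
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))  # (32,)
position_ids = torch.arange(16)[None, :]                           # (1, 16)
x = torch.randn(1, 16, 64)                                         # dummy hidden states

inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
    # Matmul runs in float32 regardless of any surrounding autocast region.
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)

print(emb.shape, emb.dtype)  # torch.Size([1, 16, 64]) torch.float32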
@@ -64,7 +64,7 @@ class DbrxRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -387,9 +387,14 @@ class DbrxFlashAttention2(DbrxAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
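The same five-line dtype selection recurs in the flash-attention hunks below (DiffLlama, DistilBert, Esm, Falcon, GPTBigCode, GPT-Neo, GPT-J, Jamba, JetMoe, Mimi, Moshi, Nemotron, OLMoE, Phimoe, Qwen2MoE). A reduced sketch of the surrounding logic, with a dummy `config` and random tensors standing in for the module state, shows where `target_dtype` ends up: in the real forward passes (not shown in the hunks) the fp32 states are cast back to it before the flash-attention kernel is called, since those kernels only accept fp16/bf16.

import torch
from types import SimpleNamespace

# Dummy stand-ins; the real code gets these from the module and its config.
config = SimpleNamespace(_pre_quantization_dtype=torch.float16)
query_states = torch.randn(1, 8, 4, 16)  # fp32, e.g. after an upcasting LayerNorm

input_dtype = query_states.dtype
device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
if input_dtype == torch.float32:
    if torch.is_autocast_enabled():
        target_dtype = (
            torch.get_autocast_dtype(device_type)
            if hasattr(torch, "get_autocast_dtype")
            else torch.get_autocast_gpu_dtype()
        )
    # Handle the case where the model is quantized
    elif hasattr(config, "_pre_quantization_dtype"):
        target_dtype = config._pre_quantization_dtype
    else:
        target_dtype = torch.float16  # in the models: the weight dtype of a projection layer

    # Flash-attention kernels only support fp16/bf16, so cast back before the call.
    query_states = query_states.to(target_dtype)

print(query_states.dtype)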
@@ -219,7 +219,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
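For reference, a self-contained version of the upcast-and-reorder block this hunk touches, with tiny made-up shapes; the preallocated `attn_weights` buffer mirrors what the model code passes in. `torch.autocast` and `torch.amp.autocast` behave the same here, so the switch is purely for consistency with the rest of the modeling code, per the commit message.

import torch

bsz, num_heads, q_seq_len, k_seq_len, dk = 1, 2, 4, 4, 8
query = torch.randn(bsz, num_heads, q_seq_len, dk)
key = torch.randn(bsz, num_heads, k_seq_len, dk)

scale_factor = 1.0 / (dk ** 0.5)
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32)

# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
with torch.autocast(query.device.type, enabled=False):
    q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

print(attn_weights.shape, attn_weights.dtype)  # torch.Size([1, 2, 4, 4]) torch.float32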
@@ -306,9 +306,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -239,9 +239,14 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         # in fp32. (DiffLlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -289,9 +289,14 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if query_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -459,9 +459,14 @@ class EsmFlashAttention2(EsmSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -133,9 +133,14 @@ class EsmForProteinFoldingOutput(ModelOutput):
     max_predicted_aligned_error: Optional[torch.FloatTensor] = None


-def is_fp16_enabled():
+def is_fp16_enabled(device_type):
     # Autocast world
-    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+    autocast_dtype = (
+        torch.get_autocast_dtype(device_type)
+        if hasattr(torch, "get_autocast_dtype")
+        else torch.get_autocast_gpu_dtype()
+    )
+    fp16_enabled = autocast_dtype == torch.float16
     fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

     return fp16_enabled
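A sketch of how the reworked `is_fp16_enabled` is used two hunks further down (EsmFoldTriangleMultiplicativeUpdate and EsmFoldInvariantPointAttention): the caller now passes the device type of the tensor it is about to upcast instead of implicitly asking about CUDA. The function body is copied from the hunk above; the small driver around it is illustrative only.

import torch


def is_fp16_enabled(device_type):
    # Autocast world
    autocast_dtype = (
        torch.get_autocast_dtype(device_type)
        if hasattr(torch, "get_autocast_dtype")
        else torch.get_autocast_gpu_dtype()
    )
    fp16_enabled = autocast_dtype == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
    return fp16_enabled


a = torch.randn(4, 4)
device_type = a.device.type if a.device.type != "mps" else "cpu"
if is_fp16_enabled(device_type):
    # Upcast-sensitive math runs in fp32 with autocast switched off.
    with torch.autocast(device_type=device_type, enabled=False):
        out = a.float() @ a.float().T
else:
    out = a @ a.T
print(out.dtype)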
@@ -885,8 +890,9 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)

-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = a.device.type if a.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
@@ -1499,8 +1505,9 @@ class EsmFoldInvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()

         # [*, H, N_res, N_res]
-        if is_fp16_enabled():
-            with torch.cuda.amp.autocast(enabled=False):
+        device_type = q.device.type if q.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
@@ -488,9 +488,14 @@ class FalconFlashAttention2(FalconAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
+        device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -229,7 +229,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.amp.autocast(query.device.type, enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
@@ -343,9 +343,14 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -323,9 +323,14 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)

+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if query.dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -355,9 +355,14 @@ class GPTJFlashAttention2(GPTJAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query.dtype
+        device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -22,7 +22,6 @@ from typing import Any, Optional, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -280,7 +279,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
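The import-level hunk and this call-site hunk go together: the bare `autocast` imported from `torch.cuda.amp` only controls CUDA autocast, while `torch.autocast(device_type, enabled=False)` disables autocast for whichever backend the tensors are actually on. A small self-contained comparison; CPU autocast is used here only because it runs anywhere, the relevant part is the `device_type` argument.

import torch

x = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ x                      # computed in bf16 under CPU autocast
    with torch.autocast(device_type="cpu", enabled=False):
        z = x.float() @ x.float()  # autocast disabled for this device: stays fp32

print(y.dtype, z.dtype)  # torch.bfloat16 torch.float32

Wrapping the inner block in `torch.cuda.amp.autocast(enabled=False)` instead would leave the CPU (or XPU) autocast region active, which is the kind of mismatch this commit removes.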
@@ -412,9 +412,14 @@ class JambaFlashAttention2(JambaAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -710,9 +710,14 @@ class JetMoeFlashAttention2(JetMoeAttention):
         # in fp32. (LlamaRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -661,9 +661,14 @@ class MimiFlashAttention2(MimiAttention):
         # in fp32. (MimiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -594,9 +594,14 @@ class MoshiFlashAttention2(MoshiAttention):
         # in fp32. (MoshiRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -51,11 +51,17 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)


-def _cast_if_autocast_enabled(*args):
+def _cast_if_autocast_enabled(device_type, *args):
     if not torch.is_autocast_enabled():
         return args
     else:
-        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
+        # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+        target_dtype = (
+            torch.get_autocast_dtype(device_type)
+            if hasattr(torch, "get_autocast_dtype")
+            else torch.get_autocast_gpu_dtype()
+        )
+        return torch.amp.autocast_mode._cast(args, device_type, target_dtype)


 class NemotronLayerNorm1P(nn.LayerNorm):
@@ -71,8 +77,11 @@ class NemotronLayerNorm1P(nn.LayerNorm):
         super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)

     def forward(self, input: Tensor) -> Tensor:
-        args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.amp.autocast(input.device.type, enabled=False):
+        device_type = input.device.type if input.device.type != "mps" else "cpu"
+        args = _cast_if_autocast_enabled(
+            device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
+        )
+        with torch.autocast(device_type=input.device.type, enabled=False):
             return F.layer_norm(*args)

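A reduced sketch of what the two Nemotron hunks do together, using a plain `F.layer_norm` call in place of the full module: when autocast is active, the inputs are pre-cast to the autocast dtype of the input's own device, and the normalization itself then runs with autocast disabled. `torch.amp.autocast_mode._cast` is a private PyTorch helper; it appears here only because the diff itself relies on it, and the driver code around the function is illustrative.

import torch
import torch.nn.functional as F


def _cast_if_autocast_enabled(device_type, *args):
    if not torch.is_autocast_enabled():
        return args
    else:
        # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
        target_dtype = (
            torch.get_autocast_dtype(device_type)
            if hasattr(torch, "get_autocast_dtype")
            else torch.get_autocast_gpu_dtype()
        )
        return torch.amp.autocast_mode._cast(args, device_type, target_dtype)


hidden = torch.randn(2, 8)
weight = torch.ones(8)
bias = torch.zeros(8)

device_type = hidden.device.type if hidden.device.type != "mps" else "cpu"
# "+ 1" mirrors NemotronLayerNorm1P, which shifts the learned scale by one.
args = _cast_if_autocast_enabled(device_type, hidden, (8,), weight + 1, bias, 1e-5)
with torch.autocast(device_type=hidden.device.type, enabled=False):
    out = F.layer_norm(*args)
print(out.shape, out.dtype)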
@@ -344,9 +353,15 @@ class NemotronFlashAttention2(NemotronAttention):
         # in fp32. (NemotronRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -421,9 +421,14 @@ class OlmoeFlashAttention2(OlmoeAttention):
         # in fp32. (OlmoeRMSNorm handles it correctly)

         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -369,9 +369,14 @@ class PhimoeFlashAttention2(PhimoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -2638,7 +2638,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
@@ -2906,7 +2906,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         batch_size, seq_len = x.shape[0], x.shape[1]
         t = torch.arange(seq_len, device=x.device)
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
             freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
             freqs = torch.stack((freqs, freqs), dim=-1)
@@ -418,9 +418,14 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
             # Handle the case where the model is quantized
             elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
@@ -78,7 +78,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        device_type = device_type if device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -462,7 +462,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)

         torch.manual_seed(42)
-        with torch.amp.autocast(device_type=torch_device, dtype=dtype):
+        with torch.autocast(device_type=torch_device, dtype=dtype):
             with torch.no_grad():
                 mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
                 hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
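The test-side change is the same consistency substitution (`torch.amp.autocast` to `torch.autocast`); both accept the device string the test suite's `torch_device` holds ("cuda", "xpu", "cpu"). A minimal, self-contained version of the pattern, with a plain linear layer standing in for `Mamba2Mixer` and `torch_device` resolved locally rather than imported from the test utilities:

import torch

# Resolve a device roughly the way the test suite's torch_device does.
if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"

# Illustrative choice: bf16 on CPU, fp16 on accelerators.
dtype = torch.bfloat16 if torch_device == "cpu" else torch.float16

torch.manual_seed(42)
with torch.autocast(device_type=torch_device, dtype=dtype):
    with torch.no_grad():
        layer = torch.nn.Linear(64, 64).to(torch_device)
        hidden_states = torch.rand(2, 8, 64, device=torch_device)
        out = layer(hidden_states)

print(out.dtype)  # the autocast dtype chosen for that device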