Mirror of https://github.com/huggingface/transformers.git
fix FA2 when using quantization (#28203)

commit 3b7675b2b8
parent fa21ead73d
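
This commit reorders the dtype-selection logic in the Flash Attention 2 forward path of Falcon, GPTBigCode, Llama, Mistral and Mixtral. When the input hidden states arrive in float32, the quantization check (a `_pre_quantization_dtype` attribute on the config) now takes priority over the autocast dtype, with the attention projection weight dtype kept as the final fallback. Below is a minimal sketch of the new selection order; the helper name resolve_fa2_target_dtype and the fallback_weight argument are illustrative, not part of the library.

import torch

def resolve_fa2_target_dtype(config, fallback_weight: torch.Tensor) -> torch.dtype:
    # Sketch of the priority order introduced by this commit.
    # Quantized models record the dtype they had before quantization;
    # that dtype now wins over the autocast dtype.
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # Otherwise honour autocast if it is active.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # Final fallback: the dtype of an attention projection weight (e.g. q_proj).
    return fallback_weight.dtype
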
@@ -617,11 +617,11 @@ class FalconFlashAttention2(FalconAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype
@@ -375,11 +375,11 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
         if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.c_attn.weight.dtype
@@ -528,11 +528,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -428,11 +428,11 @@ class MistralFlashAttention2(MistralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
@@ -477,11 +477,11 @@ class MixtralFlashAttention2(MixtralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
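
For reference, a hedged example of the setup this fix targets: a quantized checkpoint loaded with Flash Attention 2 enabled. The checkpoint name and the bitsandbytes settings are illustrative, and the exact keyword arguments depend on the installed transformers version.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative 4-bit quantization config; compute runs in float16.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# With this fix, float32 hidden states inside the FA2 path are cast back to the
# pre-quantization dtype rather than the autocast dtype.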