From 3b7675b2b844b02d4821b827871a21ad16dd446c Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Tue, 26 Dec 2023 08:36:41 +0530
Subject: [PATCH] fix FA2 when using quantization (#28203)

---
 src/transformers/models/falcon/modeling_falcon.py   | 8 ++++----
 .../models/gpt_bigcode/modeling_gpt_bigcode.py      | 8 ++++----
 src/transformers/models/llama/modeling_llama.py     | 8 ++++----
 src/transformers/models/mistral/modeling_mistral.py | 8 ++++----
 src/transformers/models/mixtral/modeling_mixtral.py | 8 ++++----
 5 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index a92650ec219..7c2c63f5c5f 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -617,11 +617,11 @@ class FalconFlashAttention2(FalconAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
-            # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
+            if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype
 
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 468916c28cb..71f709e3e15 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -375,11 +375,11 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query.dtype
         if input_dtype == torch.float32:
-            # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
+            if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.c_attn.weight.dtype
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 2cbe3bc9592..5f54fea8c40 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -528,11 +528,11 @@ class LlamaFlashAttention2(LlamaAttention):
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
+            if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index a83c31bce47..c8b5c211777 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -428,11 +428,11 @@ class MistralFlashAttention2(MistralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
+            if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 0f6985c3085..a32b9b1457a 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -477,11 +477,11 @@ class MixtralFlashAttention2(MixtralAttention):
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
-            # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            elif torch.is_autocast_enabled():
+            if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
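
For reference, all five hunks apply the same reordering, and the dtype-selection logic they converge on is sketched below. This is a minimal illustration, not code from the repository: the helper name `select_fa2_cast_dtype` and the `fallback_weight_dtype` argument are stand-ins for the per-model projection weights (`query_key_value`, `c_attn`, `q_proj`); only the branch order is taken from the patch.

    import torch

    def select_fa2_cast_dtype(config, fallback_weight_dtype: torch.dtype) -> torch.dtype:
        # Prefer the active autocast dtype when autocast is enabled.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized: fall back to the dtype
        # recorded before quantization.
        if hasattr(config, "_pre_quantization_dtype"):
            return config._pre_quantization_dtype
        # Otherwise, use the dtype of the attention projection weights.
        return fallback_weight_dtype

Each FlashAttention-2 forward pass casts float32 query/key/value states to the selected dtype before calling the flash-attention kernel; with this patch, an active autocast context takes precedence over `_pre_quantization_dtype` for quantized models.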