Fix duplicate & unnecessary flash attention warnings (#28557)

* fix duplicate & unnecessary flash warnings

* trigger ci

* warning_once

* if/else order

---------

Co-authored-by: Your Name <you@example.com>
fxmarty 2024-01-26 09:37:04 +01:00 committed by GitHub
parent 142ce68389
commit 8eb74c1c89
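Note: The fix replaces `logger.warning` with `logger.warning_once` so repeated model initializations do not emit the same flash-attention message over and over. Below is a minimal sketch of how a warning_once-style helper can deduplicate messages; it is an illustration only, not the exact `transformers.utils.logging` implementation.

import functools
import logging

logger = logging.getLogger(__name__)

@functools.lru_cache(None)
def warning_once(message: str) -> None:
    # lru_cache keys on the message, so an identical message is logged only once.
    logger.warning(message)

warning_once("You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour")
warning_once("You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour")  # suppressed: already emitted once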

@@ -1321,7 +1321,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
         config._attn_implementation = kwargs.pop("attn_implementation", None)
         config = cls._autoset_attn_implementation(
-            config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
+            config,
+            use_flash_attention_2=use_flash_attention_2,
+            check_device_map=False,
+            torch_dtype=torch_dtype,
         )
         if is_deepspeed_zero3_enabled():
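Note: `_from_config` now forwards `torch_dtype` to `_autoset_attn_implementation`, so the dtype check further down sees the dtype the caller actually passed instead of always seeing `None`. A hedged usage sketch of the affected path (model name and dtype are illustrative; running it requires the flash-attn package and a CUDA setup):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_config(
    config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # now reaches the flash-attention dtype check, so no spurious "no dtype" warning
)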
@@ -1396,7 +1399,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
-                config, hard_check_only=False if requested_attn_implementation is None else True
+                config,
+                hard_check_only=False if requested_attn_implementation is None else True,
             )
         else:
             config._attn_implementation = "eager"
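Note: The comment in the hunk above documents the dispatch order ("if/else order" in the commit message): Flash Attention 2 is checked first, SDPA is handled in the `elif` only when nothing (or "sdpa") was requested, and anything else falls back to eager. A simplified sketch of that ordering, not the full `_autoset_attn_implementation` logic:

def pick_attn_implementation(requested, use_flash_attention_2):
    # use_flash_attention_2 takes priority over SDPA, mirroring the comment in the diff.
    if use_flash_attention_2 or requested == "flash_attention_2":
        return "flash_attention_2"
    elif requested in (None, "sdpa"):
        return "sdpa"
    else:
        return "eager"

print(pick_attn_implementation(None, use_flash_attention_2=True))      # flash_attention_2
print(pick_attn_implementation(None, use_flash_attention_2=False))     # sdpa
print(pick_attn_implementation("eager", use_flash_attention_2=False))  # eager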
@@ -1503,20 +1507,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
         if torch_dtype is None:
-            logger.warning(
+            logger.warning_once(
                 "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
             )
         elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
-            logger.warning(
-                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
-                "No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
+            logger.warning_once(
+                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
             )
         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
         if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
             if torch.cuda.is_available():
-                logger.warning(
+                logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
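Note: Following the guidance quoted in the warnings above, loading with a supported dtype and then moving the model to GPU avoids all three messages. A usage sketch reusing the checkpoint named in the new warning string (requires a CUDA device and the flash-attn package):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,  # supported dtype: no dtype warning
)
model.to("cuda")  # on GPU: no "not initialized on GPU" warning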