Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
Fix duplicate & unnecessary flash attention warnings (#28557)
* fix duplicate & unnecessary flash warnings
* trigger ci
* warning_once
* if/else order

---------

Co-authored-by: Your Name <you@example.com>
Parent: 142ce68389
Commit: 8eb74c1c89
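All three hunks below touch PreTrainedModel in src/transformers/modeling_utils.py. The core of the fix is replacing repeated logger.warning(...) calls with logger.warning_once(...), so each Flash Attention message is emitted at most once per process. As a rough sketch only, not the library's actual code, a deduplicated warning helper can be built on functools.lru_cache; the standalone warning_once function and logger below are assumptions for illustration:

import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(None)
def warning_once(message: str) -> None:
    # lru_cache memoizes on the message text, so a second call with the same
    # string returns the cached result and never reaches logger.warning again.
    logger.warning(message)


warning_once("You are attempting to use Flash Attention 2.0 without specifying a torch dtype.")
warning_once("You are attempting to use Flash Attention 2.0 without specifying a torch dtype.")  # suppressed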
@@ -1321,7 +1321,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.
         config._attn_implementation = kwargs.pop("attn_implementation", None)
         config = cls._autoset_attn_implementation(
-            config, use_flash_attention_2=use_flash_attention_2, check_device_map=False
+            config,
+            use_flash_attention_2=use_flash_attention_2,
+            check_device_map=False,
+            torch_dtype=torch_dtype,
         )
 
         if is_deepspeed_zero3_enabled():
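The hunk above also threads torch_dtype from _from_config into _autoset_attn_implementation, so the Flash Attention 2.0 dtype check sees the dtype the model is actually loaded with. A hedged usage sketch, reusing the checkpoint named in the warning text later in this diff (assumes the flash-attn package and a CUDA device are available):

import torch
from transformers import AutoModel

# Passing an explicit fp16 dtype at load time satisfies the Flash Attention 2.0
# dtype check in _autoset_attn_implementation, so no dtype warning is emitted.
model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)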
@@ -1396,7 +1399,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
-                config, hard_check_only=False if requested_attn_implementation is None else True
+                config,
+                hard_check_only=False if requested_attn_implementation is None else True,
             )
         else:
             config._attn_implementation = "eager"
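For context on the hard_check_only argument reflowed above: an explicit attn_implementation="sdpa" request makes _check_and_enable_sdpa a hard check that raises if the architecture lacks SDPA support, while leaving attn_implementation unset keeps the check soft and falls back quietly. An illustrative call, assuming the installed transformers version supports SDPA for this architecture:

import torch
from transformers import AutoModel

# Explicit request: hard_check_only=True, so loading raises if SDPA is unsupported.
model = AutoModel.from_pretrained(
    "openai/whisper-tiny",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)

# Omitting attn_implementation keeps the check soft: SDPA is enabled only when
# the support check passes, otherwise the default attention implementation is kept.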
@@ -1503,20 +1507,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
 
         if torch_dtype is None:
-            logger.warning(
+            logger.warning_once(
                 "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
             )
         elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
-            logger.warning(
-                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
-                "No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
+            logger.warning_once(
+                "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
             )
 
         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
         if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
             if torch.cuda.is_available():
-                logger.warning(
+                logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
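The remaining warnings also become once-only, and their messages point to two remedies: run under torch.autocast, or move the model to a CUDA device after initializing it on CPU. A minimal sketch of that workflow (assumes the flash-attn package and a CUDA device; the forward pass is left as a placeholder):

import torch
from transformers import AutoModel

# Loading on CPU without a dtype now triggers the dtype warning exactly once.
model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2")
model.to("cuda")  # addresses the "not initialized on GPU" warning path

with torch.autocast(device_type="cuda", dtype=torch.float16):
    ...  # run training or inference here; autocast supplies the fp16 compute dtype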