introduce logger.warning_once and use it for grad checkpointing code (#21804)

* logger.warning_once

* style
Stas Bekman 2023-02-27 13:25:06 -08:00 committed by GitHub
parent f95f60c829
commit c7f3abc257
58 changed files with 74 additions and 57 deletions
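
What the new helper does, as a minimal self-contained sketch (assuming a transformers version that already contains this change; the loop and message are illustrative only): gradient checkpointing re-runs the forward pass on every training step, so a plain `logger.warning()` in that code path repeats the same message endlessly, while `warning_once` caches on the message text and emits it a single time.

    from transformers.utils import logging

    logger = logging.get_logger(__name__)

    for step in range(1000):
        # with logger.warning(...) this would print 1000 times;
        # warning_once emits the identical message exactly once
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )

    # a different message is still emitted (once)
    logger.warning_once("some other warning")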


@@ -638,7 +638,7 @@ class AltRobertaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1085,7 +1085,7 @@ class BartDecoder(BartPretrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -585,7 +585,7 @@ class BertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -395,7 +395,7 @@ class BertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1606,7 +1606,7 @@ class BigBirdEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -2265,7 +2265,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -557,7 +557,7 @@ class BioGptModel(BioGptPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1016,7 +1016,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1012,7 +1012,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -757,7 +757,7 @@ class BloomModel(BloomPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -769,7 +769,7 @@ class BridgeTowerTextEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -516,7 +516,7 @@ class CamembertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -901,7 +901,7 @@ class ChineseCLIPTextEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1588,7 +1588,7 @@ class ClapTextEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -548,7 +548,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                     "`use_cache=False`..."
                 )


@@ -502,7 +502,7 @@ class Data2VecTextEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -609,7 +609,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -563,7 +563,7 @@ class ElectraEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -498,7 +498,7 @@ class ErnieEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -597,7 +597,7 @@ class EsmEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                     "`use_cache=False`..."
                 )


@@ -444,7 +444,7 @@ class GitEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -853,7 +853,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -589,7 +589,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -653,7 +653,7 @@ class GPTJModel(GPTJPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -812,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -479,7 +479,7 @@ class LayoutLMEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -2136,7 +2136,7 @@ class LEDDecoder(LEDPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1055,7 +1055,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting"
                     " `use_cache=False`..."
                 )


@@ -1020,7 +1020,7 @@ class MarianDecoder(MarianPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -641,7 +641,7 @@ class MarkupLMEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1069,7 +1069,7 @@ class MBartDecoder(MBartPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -544,7 +544,7 @@ class MegatronBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1008,7 +1008,7 @@ class MT5Stack(MT5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1212,7 +1212,7 @@ class MvpDecoder(MvpPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -571,7 +571,7 @@ class NezhaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -671,7 +671,7 @@ class OPTDecoder(OPTPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1070,7 +1070,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1311,7 +1311,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1048,7 +1048,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1572,7 +1572,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -575,7 +575,7 @@ class QDQBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -578,7 +578,7 @@ class RealmEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -536,7 +536,7 @@ class RemBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -502,7 +502,7 @@ class RobertaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -504,7 +504,7 @@ class RobertaPreLayerNormEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -636,7 +636,7 @@ class RoCBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -573,7 +573,7 @@ class RoFormerEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1692,7 +1692,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -451,7 +451,7 @@ class SplinterEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1057,7 +1057,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1037,7 +1037,7 @@ class T5Stack(T5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1471,7 +1471,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -543,7 +543,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -1595,7 +1595,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -503,7 +503,7 @@ class XLMRobertaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -492,7 +492,7 @@ class XLMRobertaXLEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -566,7 +566,7 @@ class XmodEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False


@@ -14,6 +14,8 @@
 # limitations under the License.
 """ Logging utilities."""
+import functools
 import logging
 import os
 import sys
@@ -281,6 +283,21 @@ def warning_advice(self, *args, **kwargs):
 logging.Logger.warning_advice = warning_advice
 
 
+@functools.lru_cache(None)
+def warning_once(self, *args, **kwargs):
+    """
+    This method is identical to `logger.warning()`, but will emit the warning with the same message only once.
+
+    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
+    cache. The assumption here is that all warning messages are unique across the code. If they aren't, we would
+    need to switch to another type of cache that includes the caller frame information in the hashing function.
+    """
+    self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
+
+
 class EmptyTqdm:
     """Dummy tqdm which doesn't do anything."""