[debug utils] activation/weights underflow/overflow detector (#11274)
* sync
* add activation overflow debug utility
* cleanup
* document detect_overflow
* import torch
* add deprecation warning
* Apply suggestions from code review
* convert to rst, add note
* add class
* fix docs
* improve the doc
* rework to dump a lot more info about each frame
* complete expansion
* cleanup
* format
* cleanup
* doesn't have to be transformers
* Apply suggestions from code review
* wrap long line
* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 804c2974d5
commit 282f3ac3ef

295 docs/source/debugging.rst Normal file
@ -0,0 +1,295 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Debugging
=======================================================================================================================

Underflow and Overflow Detection
-----------------------------------------------------------------------------------------------------------------------

.. note::

    This feature is currently available for PyTorch only.

.. note::

    This feature can be used with any ``nn.Module``-based model.

If you start getting ``loss=NaN`` or the model exhibits some other abnormal behavior due to ``inf`` or ``nan`` in
activations or weights, you need to discover where the first underflow or overflow happens and what led to it. Luckily
you can accomplish that easily by activating a special module that will do the detection automatically.

If you're using :class:`~transformers.Trainer`, you just need to add:

.. code-block:: bash

    --debug underflow_overflow

to the normal command line arguments, or pass ``debug="underflow_overflow"`` when creating the
:class:`~transformers.TrainingArguments` object.
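For example, a minimal sketch of the programmatic route (the ``output_dir`` value here is just a placeholder):

.. code-block:: python

    from transformers import TrainingArguments

    # equivalent to passing --debug underflow_overflow on the command line
    args = TrainingArguments(
        output_dir="output",  # placeholder path
        debug="underflow_overflow",
    )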
If you're using your own training loop or another Trainer, you can accomplish the same with:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    debug_overflow = DebugUnderflowOverflow(model)

:class:`~transformers.debug_utils.DebugUnderflowOverflow` inserts hooks into the model that immediately after each
forward call will test input and output variables and also the corresponding module's weights. As soon as ``inf`` or
``nan`` is detected in at least one element of the activations or weights, the program will raise an exception and
print a report like this (this was caught with ``google/mt5-small`` under fp16 mixed precision):

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
                      encoder.block.1.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 2.57e+02 input[0]
    0.00e+00 2.85e+02 output
    [...]
                      encoder.block.2.layer.0 T5LayerSelfAttention
    6.78e-04 3.15e+03 input[0]
    2.65e-04 3.42e+03 output[0]
                 None output[1]
    2.25e-01 1.00e+04 output[2]
                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 8.76e+03 input[0]
    0.00e+00 9.74e+03 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The example output has been trimmed in the middle for brevity.

The second column shows the value of the absolute largest element, so if you have a closer look at the last few frames,
the inputs and outputs were in the range of ``1e4``. So when this training was done under fp16 mixed precision, the very
last step overflowed (since under ``fp16`` the largest number before ``inf`` is ``64e3``). To avoid overflows under
``fp16`` the activations must remain way below ``1e4``, because ``1e4 * 1e4 = 1e8``, so any matrix multiplication with
large activations is going to lead to a numerical overflow condition.
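To see this saturation behavior for yourself, here is a tiny sketch you could run (the ``6.27e4`` value is borrowed
from the report above):

.. code-block:: python

    import torch

    # the largest representable fp16 number
    print(torch.finfo(torch.float16).max)  # 65504.0

    x = torch.tensor(6.27e4, dtype=torch.float16)
    print(x * 1.1)  # tensor(inf, dtype=torch.float16) - one small step past the ceiling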
At the very start of the trace you can discover at which batch number the problem occurred (here ``Detected inf/nan
during batch_number=0`` means the problem occurred on the first batch).

Each reported frame starts by declaring the fully qualified entry for the corresponding module this frame is reporting
for. If we look just at this frame:

.. code-block::

                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output

Here, ``encoder.block.2.layer.1.layer_norm`` indicates that it was the layer norm of the first layer of the second
block of the encoder, and that the specific ``forward`` that was called is ``T5LayerNorm``'s.

Let's look at the last few frames of that report:

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The last frame reports for the ``Dropout.forward`` function, with the first entry for the only input and the second for
the only output. You can see that it was called from the ``dropout`` attribute inside the ``DenseReluDense`` class. We
can see that it happened during the first layer of the 2nd block, during the very first batch. Finally, the absolute
largest input element was ``6.27e+04`` and the same for the output was ``inf``.

You can see here that ``T5DenseGatedGeluDense.forward`` resulted in output activations whose absolute max value was
around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout``, which rescales the
remaining elements after zeroing some of them, which pushes the absolute max value to more than 64K, and we get an
overflow (``inf``).

As you can see, it's the previous frames, where the numbers start getting very large for fp16, that we need to look
into.
Let's match the report to the code from ``models/t5/modeling_t5.py``:

.. code-block:: python

    class T5DenseGatedGeluDense(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
            self.dropout = nn.Dropout(config.dropout_rate)
            self.gelu_act = ACT2FN["gelu_new"]

        def forward(self, hidden_states):
            hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
            hidden_linear = self.wi_1(hidden_states)
            hidden_states = hidden_gelu * hidden_linear
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.wo(hidden_states)
            return hidden_states

Now it's easy to see the ``dropout`` call, and all the previous calls as well.

Since the detection is happening in a forward hook, these reports are printed immediately after each ``forward``
returns.
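If you're curious what such a hook looks like, here is a stripped-down sketch of the idea (a toy model and a
simplified check, not the actual implementation):

.. code-block:: python

    import torch
    from torch import nn


    def report_hook(module, inputs, output):
        # invoked right after module.forward() returns
        if torch.is_tensor(output) and (torch.isnan(output).any() or torch.isinf(output).any()):
            print(f"inf/nan detected in the output of {module.__class__.__name__}")


    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    for mod in model.modules():
        mod.register_forward_hook(report_hook)

    model(torch.randn(2, 4))  # each submodule's hook fires after its forward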
Going back to the full report, to act on it and to fix the problem, we need to go a few frames up where the numbers
started to go up and most likely switch to the ``fp32`` mode here, so that the numbers don't overflow when multiplied
or summed up. Of course, there might be other solutions. For example, we could turn off ``amp`` temporarily if it's
enabled, after moving the original ``forward`` into a helper wrapper, like so:

.. code-block:: python

    import torch


    def _forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


    def forward(self, hidden_states):
        if torch.is_autocast_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(hidden_states)
        else:
            return self._forward(hidden_states)
Since the automatic detector only reports on inputs and outputs of full frames, once you know where to look, you may
want to analyse the intermediary stages of any specific ``forward`` function as well. In such a case you can use the
``detect_overflow`` helper function to inject the detector where you want it, for example:

.. code-block:: python

    from transformers.debug_utils import detect_overflow


    class T5LayerFF(nn.Module):
        [...]

        def forward(self, hidden_states):
            forwarded_states = self.layer_norm(hidden_states)
            detect_overflow(forwarded_states, "after layer_norm")
            forwarded_states = self.DenseReluDense(forwarded_states)
            detect_overflow(forwarded_states, "after DenseReluDense")
            return hidden_states + self.dropout(forwarded_states)

You can see that we added 2 of these, and now we track whether ``inf`` or ``nan`` for ``forwarded_states`` was
detected somewhere in between.

Actually, the detector already reports these because each of the calls in the example above is an ``nn.Module``, but
let's say you had some local direct calculations; this is how you'd do that.

Additionally, if you're instantiating the debugger in your own code, you can adjust the number of frames printed from
its default, e.g.:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
Specific batch absolute min and max value tracing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The same debugging class can be used for per-batch tracing with the underflow/overflow detection feature turned off.

Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a given
batch, and only do that for batches 1 and 3. Then you instantiate this class as:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

And now full batches 1 and 3 will be traced using the same format as the underflow/overflow detector does.

Batches are 0-indexed.

This is helpful if you know that the program starts misbehaving after a certain batch number, so you can fast-forward
right to that area. Here is a sample truncated output for such a configuration:

.. code-block::

                      *** Starting batch number=1 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.47e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]
                      decoder.dropout Dropout
    1.60e-07 2.27e+01 input[0]
    0.00e+00 2.52e+01 output
                      decoder T5Stack
         not a tensor output
                      lm_head Linear
    1.01e-06 7.92e+02 weight
    0.00e+00 1.11e+00 input[0]
    6.06e-02 8.39e+01 output
                      T5ForConditionalGeneration
         not a tensor output

                      *** Starting batch number=3 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.78e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]

Here you will get a huge number of frames dumped - as many as there were forward calls in your model - so it may or
may not be what you want, but sometimes it can be easier to use for debugging purposes than a normal debugger. For
example, if a problem starts happening at batch number 150, you can dump traces for batches 149 and 150 and compare
where the numbers started to diverge.
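For that scenario you could instantiate the debugger as, for example:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[149, 150])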
You can also specify the batch number after which to stop the training, with:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
@ -405,6 +405,7 @@ Flax), PyTorch, and/or TensorFlow.
 
    add_new_model
    fast_tokenizers
    testing
+   debugging
    serialization
 
 .. toctree::
@ -1,4 +1,4 @@
..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@ -46,3 +46,9 @@ Distributed Evaluation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.HfArgumentParser
+
+
+Debug Utilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.debug_utils.DebugUnderflowOverflow

326 src/transformers/debug_utils.py Normal file
@ -0,0 +1,326 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from .file_utils import ExplicitEnum, is_torch_available
from .utils import logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly ``nan`` or ``inf`` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model ::

        debug_overflow = DebugUnderflowOverflow(model)

    then run the training as normal and if ``nan`` or ``inf`` gets detected in at least one of the weight, input or
    output elements this module will throw an exception and will print ``max_frames_to_save`` frames that lead to this
    event, each frame reporting

    1. the fully qualified module name plus the class name whose ``forward`` was run
    2. the absolute min and max value of all elements for each module weights, and the inputs and output

    For example, here is the header and the last few frames in a detection report for ``google/mt5-small`` run in fp16
    mixed precision ::

        Detected inf/nan during batch_number=0
        Last 21 forward frames:
        abs min  abs max  metadata
        [...]
                          encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
        2.17e-07 4.50e+00 weight
        1.79e-06 4.65e+00 input[0]
        2.68e-06 3.70e+01 output
                          encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
        8.08e-07 2.66e+01 weight
        1.79e-06 4.65e+00 input[0]
        1.27e-04 2.37e+02 output
                          encoder.block.2.layer.1.DenseReluDense.wo Linear
        1.01e-06 6.44e+00 weight
        0.00e+00 9.74e+03 input[0]
        3.18e-04 6.27e+04 output
                          encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
        1.79e-06 4.65e+00 input[0]
        3.18e-04 6.27e+04 output
                          encoder.block.2.layer.1.dropout Dropout
        3.18e-04 6.27e+04 input[0]
        0.00e+00      inf output

    You can see here that ``T5DenseGatedGeluDense.forward`` resulted in output activations whose absolute max value
    was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout``, which
    rescales the remaining elements after zeroing some of them, which pushes the absolute max value to more than 64K,
    and we get an overflow.

    As you can see, it's the previous frames, where the numbers start getting very large for fp16, that we need to
    look into.

    The tracking is done in a forward hook, which gets invoked immediately after ``forward`` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example ::

        debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)


    Mode 2. Specific batch absolute min/max tracing without detection

    The second working mode is per-batch tracing with the underflow/overflow detection feature turned off.

    Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as ::

        debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.

    You can also specify the batch number after which to stop the training, with ::

        debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)

    This feature is mainly useful in the tracing mode, but you can use it in any mode.

    Args:
        model (:obj:`nn.Module`):
            The model to debug.
        max_frames_to_save (:obj:`int`, `optional`, defaults to 21):
            How many frames back to record.
        trace_batch_nums (:obj:`List[int]`, `optional`, defaults to ``[]``):
            Which batch numbers to trace (turns detection off).
        abort_after_batch_num (:obj:`int`, `optional`, defaults to :obj:`None`):
            Whether to abort after a certain batch number has finished.
    """
    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
        self.model = model
        self.trace_batch_nums = trace_batch_nums
        self.abort_after_batch_num = abort_after_batch_num

        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = "                 "

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        self.frame.append(line)

    def trace_frames(self):
        print("\n".join(self.frames))
        self.frames = []

    def reset_saved_frames(self):
        self.frames = []

    def dump_saved_frames(self):
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames = []

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
        # self.longest_module_name = max(len(v) for v in self.module_names.values())

    def analyse_variable(self, var, ctx):
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        trace_mode = True if self.batch_number in self.trace_batch_nums else False
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        # if last_frame_of_batch:
        #     self.batch_end_frame()

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )
def get_abs_min_max(var, ctx):
    abs_var = var.abs()
    return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"


def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any ``nan`` or ``inf`` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math
    that modified the variable in question.

    The function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        :obj:`True` if ``inf`` or ``nan`` was detected, :obj:`False` otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # if needed to monitor large elements can enable the following
    if 0:  # and detected:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}:  n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected


class DebugOption(ExplicitEnum):
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
@ -3154,7 +3154,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
     def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
         """
-        Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
+        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
         corresponding model
 
         Args:
@ -59,6 +59,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
 from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .dependency_versions_check import dep_version_check
 from .file_utils import (
     CONFIG_NAME,
@ -1078,6 +1079,9 @@ class Trainer:
             num_train_epochs = int(args.num_train_epochs)
             num_update_steps_per_epoch = max_steps
 
+        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
+
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@ -1301,7 +1305,7 @@ class Trainer:
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
 
-            if args.tpu_metrics_debug or args.debug:
+            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
@ -1905,7 +1909,7 @@ class Trainer:
 
         self.log(output.metrics)
 
-        if self.args.tpu_metrics_debug or self.args.debug:
+        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
 
@ -19,6 +19,7 @@ from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
+from .debug_utils import DebugOption
 from .file_utils import (
     cached_property,
     is_sagemaker_dp_enabled,
@ -191,8 +192,6 @@ class TrainingArguments:
             Rank of the process during distributed training.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
-        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            When training on TPU, whether to print debug metrics or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
@ -274,6 +273,16 @@ class TrainingArguments:
             The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
             labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
             label_smoothing_factor + label_smoothing_factor/num_labels` respectively.
+        debug (:obj:`str` or list of :class:`~transformers.debug_utils.DebugOption`, `optional`, defaults to :obj:`""`):
+            Enable one or more debug features. This is an experimental feature.
+
+            Possible options are:
+
+            - :obj:`"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that
+              led to the event
+            - :obj:`"tpu_metrics_debug"`: print debug metrics on TPU
+
+            The options should be separated by whitespaces.
         adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
             :class:`~transformers.AdamW`.
@ -437,9 +446,18 @@ class TrainingArguments:
     )
     tpu_metrics_debug: bool = field(
         default=False,
-        metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
+        metadata={
+            "help": "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+        },
     )
+    debug: str = field(
+        default="",
+        metadata={
+            "help": "Whether or not to enable debug mode. Current options: "
+            "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+            "`tpu_metrics_debug` (print debug metrics on TPU)."
+        },
+    )
-    debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
 
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
@ -631,6 +649,16 @@ class TrainingArguments:
         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
 
+        if self.tpu_metrics_debug:
+            warnings.warn(
+                "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
+                FutureWarning,
+            )
+            self.debug += " tpu_metrics_debug"
+            self.tpu_metrics_debug = False
+        if isinstance(self.debug, str):
+            self.debug = [DebugOption(s) for s in self.debug.split()]
+
         if self.deepspeed:
             # - must be run very last in arg parsing, since it will use a lot of these settings.
             # - must be run before the model is created.