[debug utils] activation/weights underflow/overflow detector (#11274)

* sync

* add activation overflow debug utility

* cleanup

* document detect_overflow

* import torch

* add deprecation warning

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* convert to rst, add note

* add class

* fix docs

* improve the doc

* rework to dump a lot more info about each frame

* complete expansion

* cleanup

* format

* cleanup

* doesn't have to be transformers

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* wrap long line

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman 2021-04-30 11:15:46 -07:00 committed by GitHub
parent 804c2974d5
commit 282f3ac3ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 668 additions and 8 deletions

295
docs/source/debugging.rst Normal file
View File

@ -0,0 +1,295 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Debugging
=======================================================================================================================
Underflow and Overflow Detection
-----------------------------------------------------------------------------------------------------------------------
.. note::
This feature is currently available for PyTorch-only.
.. note::
This feature can be used with any ``nn.Module``-based model
If you start getting ``loss=NaN`` or the model inhibits some other abnormal behavior due to ``inf`` or ``nan`` in
activations or weights one needs to discover where the first underflow or overflow happens and what led to it. Luckily
you can accomplish that easily by activating a special module that will do the detection automatically.
If you're using :class:`~transformers.Trainer`, you just need to add:
.. code-block:: bash
--debug underflow_overflow
to the normal command line arguments, or pass ``debug="underflow_overflow"`` when creating the
:class:`~transformers.TrainingArguments` object.
If you're using your own training loop or another Trainer you can accomplish the same with:
.. code-block:: python
from .debug_utils import DebugUnderflowOverflow
debug_overflow = DebugUnderflowOverflow(model)
:class:`~transformers.debug_utils.DebugUnderflowOverflow` inserts hooks into the model that immediately after each
forward call will test input and output variables and also the corresponding module's weights. As soon as ``inf`` or
``nan`` is detected in at least one element of the activations or weights, the program will assert and print a report
like this (this was caught with ``google/mt5-small`` under fp16 mixed precision):
.. code-block::
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
encoder.block.1.layer.1.DenseReluDense.dropout Dropout
0.00e+00 2.57e+02 input[0]
0.00e+00 2.85e+02 output
[...]
encoder.block.2.layer.0 T5LayerSelfAttention
6.78e-04 3.15e+03 input[0]
2.65e-04 3.42e+03 output[0]
None output[1]
2.25e-01 1.00e+04 output[2]
encoder.block.2.layer.1.layer_norm T5LayerNorm
8.69e-02 4.18e-01 weight
2.65e-04 3.42e+03 input[0]
1.79e-06 4.65e+00 output
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.dropout Dropout
0.00e+00 8.76e+03 input[0]
0.00e+00 9.74e+03 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
The example output has been trimmed in the middle for brevity.
The second column shows the value of the absolute largest element, so if you have a closer look at the last few frames,
the inputs and outputs were in the range of ``1e4``. So when this training was done under fp16 mixed precision the very
last step overflowed (since under ``fp16`` the largest number before ``inf`` is ``64e3``). To avoid overflows under
``fp16`` the activations must remain way below ``1e4``, because ``1e4 * 1e4 = 1e8`` so any matrix multiplication with
large activations is going to lead to a numerical overflow condition.
At the very start of the trace you can discover at which batch number the problem occurred (here ``Detected inf/nan
during batch_number=0`` means the problem occurred on the first batch).
Each reported frame starts by declaring the fully qualified entry for the corresponding module this frame is reporting
for. If we look just at this frame:
.. code-block::
encoder.block.2.layer.1.layer_norm T5LayerNorm
8.69e-02 4.18e-01 weight
2.65e-04 3.42e+03 input[0]
1.79e-06 4.65e+00 output
Here, ``encoder.block.2.layer.1.layer_norm`` indicates that it was a layer norm for the first layer, of the second
block of the encoder. And the specific calls of the ``forward`` is ``T5LayerNorm``.
Let's look at the last few frames of that report:
.. code-block::
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
[...]
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
The last frame reports for ``Dropout.forward`` function with the first entry for the only input and the second for the
only output. You can see that it was called from an attribute ``dropout`` inside ``DenseReluDense`` class. We can see
that it happened during the first layer, of the 2nd block, during the very first batch. Finally, the absolute largest
input elements was ``6.27e+04`` and same for the output was ``inf``.
You can see here, that ``T5DenseGatedGeluDense.forward`` resulted in output activations, whose absolute max value was
around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout`` which renormalizes
the weights, after it zeroed some of the elements, which pushes the absolute max value to more than 64K, and we get an
overlow (``inf``).
As you can see it's the previous frames that we need to look into when the numbers start going into very large for fp16
numbers.
Let's match the report to the code from ``models/t5/modeling_t5.py``:
.. code-block:: python
class T5DenseGatedGeluDense(nn.Module):
def __init__(self, config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.gelu_act = ACT2FN["gelu_new"]
def forward(self, hidden_states):
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
Now it's easy to see the ``dropout`` call, and all the previous calls as well.
Since the detection is happening in a forward hook, these reports are printed immediately after each ``forward``
returns.
Going back to the full report, to act on it and to fix the problem, we need to go a few frames up where the numbers
started to go up and most likely switch to the ``fp32`` mode here, so that the numbers don't overflow when multiplied
or summed up. Of course, there might be other solutions. For example, we could turn off ``amp`` temporarily if it's
enabled, after moving the original ``forward`` into a helper wrapper, like so:
.. code-block:: python
def _forward(self, hidden_states):
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
import torch
def forward(self, hidden_states):
if torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(hidden_states)
else:
return self._forward(hidden_states)
Since the automatic detector only reports on inputs and outputs of full frames, once you know where to look, you may
want to analyse the intermediary stages of any specific ``forward`` function as well. In such a case you can use the
``detect_overflow`` helper function to inject the detector where you want it, for example:
.. code-block:: python
from debug_utils import detect_overflow
class T5LayerFF(nn.Module):
[...]
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
detect_overflow(forwarded_states, "after layer_norm")
forwarded_states = self.DenseReluDense(forwarded_states)
detect_overflow(forwarded_states, "after DenseReluDense")
return hidden_states + self.dropout(forwarded_states)
You can see that we added 2 of these and now we track if ``inf`` or ``nan`` for ``forwarded_states`` was detected
somewhere in between.
Actually, the detector already reports these because each of the calls in the example above is a `nn.Module``, but
let's say if you had some local direct calculations this is how you'd do that.
Additionally, if you're instantiating the debugger in your own code, you can adjust the number of frames printed from
its default, e.g.:
.. code-block:: python
from .debug_utils import DebugUnderflowOverflow
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
Specific batch absolute mix and max value tracing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The same debugging class can be used for per-batch tracing with the underflow/overflow detection feature turned off.
Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a given
batch, and only do that for batches 1 and 3. Then you instantiate this class as:
.. code-block:: python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3])
And now full batches 1 and 3 will be traced using the same format as the underflow/overflow detector does.
Batches are 0-indexed.
This is helpful if you know that the program starts misbehaving after a certain batch number, so you can fast-forward
right to that area. Here is a sample truncated output for such configuration:
.. code-block::
*** Starting batch number=1 ***
abs min abs max metadata
shared Embedding
1.01e-06 7.92e+02 weight
0.00e+00 2.47e+04 input[0]
5.36e-05 7.92e+02 output
[...]
decoder.dropout Dropout
1.60e-07 2.27e+01 input[0]
0.00e+00 2.52e+01 output
decoder T5Stack
not a tensor output
lm_head Linear
1.01e-06 7.92e+02 weight
0.00e+00 1.11e+00 input[0]
6.06e-02 8.39e+01 output
T5ForConditionalGeneration
not a tensor output
*** Starting batch number=3 ***
abs min abs max metadata
shared Embedding
1.01e-06 7.92e+02 weight
0.00e+00 2.78e+04 input[0]
5.36e-05 7.92e+02 output
[...]
Here you will get a huge number of frames dumped - as many as there were forward calls in your model, so it may or may
not what you want, but sometimes it can be easier to use for debugging purposes than a normal debugger. For example, if
a problem starts happening at batch number 150. So you can dump traces for batches 149 and 150 and compare where
numbers started to diverge.
You can also specify the batch number after which to stop the training, with:
.. code-block:: python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3)

View File

@ -405,6 +405,7 @@ Flax), PyTorch, and/or TensorFlow.
add_new_model
fast_tokenizers
testing
debugging
serialization
.. toctree::

View File

@ -1,4 +1,4 @@
..
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@ -46,3 +46,9 @@ Distributed Evaluation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.HfArgumentParser
Debug Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.debug_utils.DebugUnderflowOverflow

View File

@ -0,0 +1,326 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from .file_utils import ExplicitEnum, is_torch_available
from .utils import logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class DebugUnderflowOverflow:
"""
This debug class helps detect and understand where the model starts getting very large or very small, and more
importantly ``nan`` or ``inf`` weight and activation elements.
There are 2 working modes:
1. Underflow/overflow detection (default)
2. Specific batch absolute min/max tracing without detection
Mode 1: Underflow/overflow detection
To activate the underflow/overflow detection, initialize the object with the model ::
debug_overflow = DebugUnderflowOverflow(model)
then run the training as normal and if ``nan`` or ``inf`` gets detected in at least one of the weight, input or
output elements this module will throw an exception and will print ``max_frames_to_save`` frames that lead to this
event, each frame reporting
1. the fully qualified module name plus the class name whose ``forward`` was run
2. the absolute min and max value of all elements for each module weights, and the inputs and output
For example, here is the header and the last few frames in detection report for ``google/mt5-small`` run in fp16 mixed precision ::
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
[...]
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
You can see here, that ``T5DenseGatedGeluDense.forward`` resulted in output activations, whose absolute max value
was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout`` which
renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
64K, and we get an overlow.
As you can see it's the previous frames that we need to look into when the numbers start going into very large for
fp16 numbers.
The tracking is done in a forward hook, which gets invoked immediately after ``forward`` has completed.
By default the last 21 frames are printed. You can change the default to adjust for your needs. For example ::
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
Mode 2. Specific batch absolute min/max tracing without detection
The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a
given batch, and only do that for batches 1 and 3. Then you instantiate this class as ::
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3])
And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
fast-forward right to that area.
You can also specify the batch number after which to stop the training, with ::
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3)
This feature is mainly useful in the tracing mode, but you can use it for any more.
Args:
model (:obj:`nn.Module`):
The model to debug.
max_frames_to_save (:obj:`int`, `optional`, defaults to 21):
How many frames back to record
trace_batch_nums(:obj:`List[int]`, `optional`, defaults to ``[]``):
Which batch numbers to trace (turns detection off)
abort_after_batch_num (:obj:`int`, `optional`, defaults to :obj:`None`):
Whether to abort after a certain batch number has finished
"""
def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
self.model = model
self.trace_batch_nums = trace_batch_nums
self.abort_after_batch_num = abort_after_batch_num
# keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
self.frames = collections.deque([], max_frames_to_save)
self.frame = []
self.batch_number = 0
self.total_calls = 0
self.detected_overflow = False
self.prefix = " "
self.analyse_model()
self.register_forward_hook()
def save_frame(self, frame=None):
if frame is not None:
self.expand_frame(frame)
self.frames.append("\n".join(self.frame))
self.frame = [] # start a new frame
def expand_frame(self, line):
self.frame.append(line)
def trace_frames(self):
print("\n".join(self.frames))
self.frames = []
def reset_saved_frames(self):
self.frames = []
def dump_saved_frames(self):
print(f"\nDetected inf/nan during batch_number={self.batch_number}")
print(f"Last {len(self.frames)} forward frames:")
print(f"{'abs min':8} {'abs max':8} metadata")
print("\n".join(self.frames))
print("\n\n")
self.frames = []
def analyse_model(self):
# extract the fully qualified module names, to be able to report at run time. e.g.:
# encoder.block.2.layer.0.SelfAttention.o
#
# for shared weights only the first shared module name will be registered
self.module_names = {m: name for name, m in self.model.named_modules()}
# self.longest_module_name = max(len(v) for v in self.module_names.values())
def analyse_variable(self, var, ctx):
if torch.is_tensor(var):
self.expand_frame(get_abs_min_max(var, ctx))
if detect_overflow(var, ctx):
self.detected_overflow = True
elif var is None:
self.expand_frame(f"{'None':>17} {ctx}")
else:
self.expand_frame(f"{'not a tensor':>17} {ctx}")
def batch_start_frame(self):
self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
def batch_end_frame(self):
self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")
def create_frame(self, module, input, output):
self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
# params
for name, p in module.named_parameters(recurse=False):
self.analyse_variable(p, name)
# inputs
if isinstance(input, tuple):
for i, x in enumerate(input):
self.analyse_variable(x, f"input[{i}]")
else:
self.analyse_variable(input, "input")
# outputs
if isinstance(output, tuple):
for i, x in enumerate(output):
# possibly a tuple of tuples
if isinstance(x, tuple):
for j, y in enumerate(x):
self.analyse_variable(y, f"output[{i}][{j}]")
else:
self.analyse_variable(x, f"output[{i}]")
else:
self.analyse_variable(output, "output")
self.save_frame()
def register_forward_hook(self):
self.model.apply(self._register_forward_hook)
def _register_forward_hook(self, module):
module.register_forward_hook(self.forward_hook)
def forward_hook(self, module, input, output):
# - input is a tuple of packed inputs (could be non-Tensors)
# - output could be a Tensor or a tuple of Tensors and non-Tensors
last_frame_of_batch = False
trace_mode = True if self.batch_number in self.trace_batch_nums else False
if trace_mode:
self.reset_saved_frames()
if self.total_calls == 0:
self.batch_start_frame()
self.total_calls += 1
# count batch numbers - the very first forward hook of the batch will be called when the
# batch completes - i.e. it gets called very last - we know this batch has finished
if module == self.model:
self.batch_number += 1
last_frame_of_batch = True
self.create_frame(module, input, output)
# if last_frame_of_batch:
# self.batch_end_frame()
if trace_mode:
self.trace_frames()
if last_frame_of_batch:
self.batch_start_frame()
if self.detected_overflow and not trace_mode:
self.dump_saved_frames()
# now we can abort, as it's pointless to continue running
raise ValueError(
"DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
"Please scroll up above this traceback to see the activation values prior to this event."
)
# abort after certain batch if requested to do so
if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
raise ValueError(
f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg"
)
def get_abs_min_max(var, ctx):
abs_var = var.abs()
return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
def detect_overflow(var, ctx):
"""
Report of the tensor contains any ``nan`` and ``inf`` entries.
This is useful for detecting overflows/underflows and best to call right after the function that did some math that
modified the variable in question.
The function contains a few other helper features that you can enable and tweak directly if you want to track
various other things.
Args:
var: tensor variable to check
ctx: the message to print as a context
Return:
True if ``inf`` or ``nan`` was detected, False otherwise
"""
detected = False
if torch.isnan(var).any().item():
detected = True
print(f"{ctx} has nans")
if torch.isinf(var).any().item():
detected = True
print(f"{ctx} has infs")
# if needed to monitor large elements can enable the following
if 0: # and detected:
n100 = var[torch.ge(var.abs(), 100)]
if n100.numel() > 0:
print(f"{ctx}: n100={n100.numel()}")
n1000 = var[torch.ge(var.abs(), 1000)]
if n1000.numel() > 0:
print(f"{ctx}: n1000={n1000.numel()}")
n10000 = var[torch.ge(var.abs(), 10000)]
if n10000.numel() > 0:
print(f"{ctx}: n10000={n10000.numel()}")
if 0:
print(f"min={var.min():9.2e} max={var.max():9.2e}")
if 0:
print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
return detected
class DebugOption(ExplicitEnum):
UNDERFLOW_OVERFLOW = "underflow_overflow"
TPU_METRICS_DEBUG = "tpu_metrics_debug"

View File

@ -3154,7 +3154,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
"""
Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
corresponding model
Args:

View File

@ -59,6 +59,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .file_utils import (
CONFIG_NAME,
@ -1078,6 +1079,9 @@ class Trainer:
num_train_epochs = int(args.num_train_epochs)
num_update_steps_per_epoch = max_steps
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@ -1301,7 +1305,7 @@ class Trainer:
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if args.tpu_metrics_debug or args.debug:
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
@ -1905,7 +1909,7 @@ class Trainer:
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())

View File

@ -19,6 +19,7 @@ from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption
from .file_utils import (
cached_property,
is_sagemaker_dp_enabled,
@ -191,8 +192,6 @@ class TrainingArguments:
Rank of the process during distributed training.
tpu_num_cores (:obj:`int`, `optional`):
When training on TPU, the number of TPU cores (automatically passed by launcher script).
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
When training on TPU, whether to print debug metrics or not.
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
@ -274,6 +273,16 @@ class TrainingArguments:
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
label_smoothing_factor + label_smoothing_factor/num_labels` respectively.
debug (:obj:`str` or list of :class:`~transformers.debug_utils.DebugOption`, `optional`, defaults to :obj:`""`):
Enable one or more debug features. This is an experimental feature.
Possible options are:
- :obj:`"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that
led to the event
- :obj:`"tpu_metrics_debug"`: print debug metrics on TPU
The options should be separated by whitespaces.
adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
:class:`~transformers.AdamW`.
@ -437,9 +446,18 @@ class TrainingArguments:
)
tpu_metrics_debug: bool = field(
default=False,
metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
metadata={
"help": "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
},
)
debug: str = field(
default="",
metadata={
"help": "Whether or not to enable debug mode. Current options: "
"`underflow_overflow` (Detect underflow and overflow in activations and weights), "
"`tpu_metrics_debug` (print debug metrics on TPU)."
},
)
debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
@ -631,6 +649,16 @@ class TrainingArguments:
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
FutureWarning,
)
self.debug += " tpu_metrics_debug"
self.tpu_metrics_debug = False
if isinstance(self.debug, str):
self.debug = [DebugOption(s) for s in self.debug.split()]
if self.deepspeed:
# - must be run very last in arg parsing, since it will use a lot of these settings.
# - must be run before the model is created.