mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add an error message that fires when Reformer is not in training mode, but one runs .backward() (#11117)
This commit is contained in:
parent
f1b938fda8
commit
95dab34d55
@ -1512,6 +1512,10 @@ class ReformerLayer(nn.Module):
|
||||
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
|
||||
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
|
||||
|
||||
assert (
|
||||
self.training
|
||||
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
|
||||
|
||||
with torch.enable_grad():
|
||||
next_attn_output.requires_grad = True
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user