Add an error message that fires when Reformer is not in training mode, but one runs .backward() (#11117)

This commit is contained in:
Yusuke Mori 2021-04-21 07:23:37 +09:00 committed by GitHub
parent f1b938fda8
commit 95dab34d55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1512,6 +1512,10 @@ class ReformerLayer(nn.Module):
# Implementation of RevNet (see Fig. 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
# This code is heavily inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
assert (
self.training
), "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the model into training mode."
with torch.enable_grad():
next_attn_output.requires_grad = True