From 83b1e6ac9e81cbb053ee272a4a4fcb0b6fac06ab Mon Sep 17 00:00:00 2001
From: Rosanne Liu
Date: Sun, 3 Nov 2019 04:51:57 +0000
Subject: [PATCH] fix the loss backward issue

(cherry picked from commit 566468cc984c6ec7e10dfc62b5b4191781a99cd2)
---
 examples/run_pplm.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/run_pplm.py b/examples/run_pplm.py
index 30853c68c33..59ae8a9299d 100644
--- a/examples/run_pplm.py
+++ b/examples/run_pplm.py
@@ -36,6 +36,7 @@ from tqdm import trange
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from IPython import embed

 PPLM_BOW = 1
 PPLM_DISCRIM = 2
@@ -246,8 +247,8 @@ def perturb_past(
                 inputs_embeds=inputs_embeds
             )
             # get expected hidden states
-            unpert_hidden = curr_all_hidden[1]
-            accumulated_hidden += torch.sum(unpert_hidden, dim=1)
+            unpert_hidden = curr_all_hidden[-1]
+            accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
             prediction = classifier(
                 accumulated_hidden / (curr_length + 1 + horizon_length)
             )
@@ -257,7 +258,7 @@
             discrim_loss += ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())

-        if kl_scale > 0.0:
+        if kl_scale >= 0.0:
             unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
             unpert_probs = (
                 unpert_probs + SMALL_CONST *
@@ -270,7 +271,7 @@
                 torch.FloatTensor
             ).cuda().detach()
             corrected_probs = probs + correction.detach()
-            kl_loss += kl_scale * (
+            kl_loss = kl_scale * (
                 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
             )
             print(' kl_loss', (kl_loss).data.cpu().numpy())
@@ -280,7 +281,7 @@
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

         # compute gradients
-        loss.backward(retain_graph=True)
+        loss.backward()

         # calculate gradient norms
         if grad_norms is not None and loss_type == PPLM_BOW:
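
Why backward() can run here without retain_graph=True (a simplified sketch, not the actual run_pplm.py code): the patch detaches the hidden-state sum that is carried across iterations and assigns the KL term with = instead of accumulating it with +=, so each iteration's loss only references the computation graph built in that iteration and PyTorch is free to release the graph after every backward pass. The snippet below illustrates that pattern; classifier, perturbation, and the update step are stand-ins and not the functions or hyperparameters used in the example script.

# Illustrative sketch only -- assumed names, not the PPLM implementation.
import torch

classifier = torch.nn.Linear(4, 2)                  # stands in for the attribute discriminator
perturbation = torch.zeros(4, requires_grad=True)   # the quantity being optimized each step
accumulated_hidden = torch.zeros(4)                 # running sum carried across iterations
label = torch.tensor([1])
ce_loss = torch.nn.CrossEntropyLoss()

for step in range(3):
    hidden = torch.tanh(perturbation + torch.randn(4))  # fresh forward pass each step

    # The current hidden state enters the prediction with gradients, but the
    # carried-over sum stays detached, so backward() never has to reach into a
    # graph that was already freed on a previous iteration.
    prediction = classifier((accumulated_hidden + hidden).unsqueeze(0) / (step + 1))
    loss = ce_loss(prediction, label)                # rebuilt (assigned) every iteration

    loss.backward()                                  # safe without retain_graph=True
    with torch.no_grad():
        perturbation -= 0.1 * perturbation.grad
        perturbation.grad.zero_()

    # Carry the state forward outside the graph, mirroring the .detach() in the patch.
    accumulated_hidden = accumulated_hidden + hidden.detach()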