From d5b82bb70c2e8c4b184a6f2a7d1c91d7fd156956 Mon Sep 17 00:00:00 2001
From: jacksukk
Date: Fri, 15 Oct 2021 09:46:09 +0800
Subject: [PATCH] Fixed horizon_length for PPLM (#13886)

* fixed horizon_length

* fixed horizon_length

* fix style
---
 examples/research_projects/pplm/run_pplm.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/examples/research_projects/pplm/run_pplm.py b/examples/research_projects/pplm/run_pplm.py
index 4be4f01fd4d..4872118433c 100644
--- a/examples/research_projects/pplm/run_pplm.py
+++ b/examples/research_projects/pplm/run_pplm.py
@@ -181,7 +181,14 @@ def perturb_past(
             for _ in range(horizon_length):
                 inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                 lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
-                curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
+                curr_all_logits, curr_unpert_past, curr_all_hidden = (
+                    lm_output["logits"],
+                    lm_output["past_key_values"],
+                    lm_output["hidden_states"],
+                )
+                curr_logits = curr_all_logits[:, -1, :]
+                curr_probs = nn.functional.softmax(curr_logits, dim=-1)
+                curr_probs = torch.unsqueeze(curr_probs, dim=1)
                 curr_hidden = curr_all_hidden[-1]
                 new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
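
Note on the change: before this patch, curr_probs was never refreshed inside the horizon loop, so every iteration beyond the first fed the model the same stale distribution; the fix derives the next step's probabilities from the freshly predicted logits. The standalone sketch below illustrates the corrected loop. It is a minimal approximation, not the patched function itself: the GPT-2 checkpoint, the random starting `probs`, and the zero-initialized `accumulated_hidden` are stand-ins for the real PPLM state inside perturb_past.

    import torch
    import torch.nn as nn
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
    model.eval()

    wte = model.resize_token_embeddings()       # input embedding matrix (vocab x dim)
    vocab_size = wte.weight.shape[0]

    # Stand-in for the distribution over next tokens that PPLM has at this point.
    probs = nn.functional.softmax(torch.randn(1, vocab_size), dim=-1)

    curr_unpert_past = None                     # no cached keys/values yet
    curr_probs = torch.unsqueeze(probs, dim=1)  # shape: (batch, 1, vocab)
    horizon_length = 3
    accumulated_hidden = 0

    with torch.no_grad():
        for _ in range(horizon_length):
            # Feed the expected token embedding under the current distribution.
            inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
            lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
            curr_unpert_past = lm_output["past_key_values"]
            curr_hidden = lm_output["hidden_states"][-1]
            # The fix: recompute curr_probs from the new logits each step,
            # instead of reusing the same distribution for the whole horizon.
            curr_logits = lm_output["logits"][:, -1, :]
            curr_probs = torch.unsqueeze(nn.functional.softmax(curr_logits, dim=-1), dim=1)
            accumulated_hidden = accumulated_hidden + torch.sum(curr_hidden, dim=1)

With horizon_length == 1 the old and new code behave identically, which is why the bug only surfaced for longer horizons.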