Mirror of https://github.com/huggingface/transformers.git
Fixed horizon_length for PPLM (#13886)
* fixed horizon_length

* fixed horizon_length

* fix style
parent 5b317f7ea4
commit d5b82bb70c
@@ -181,7 +181,14 @@ def perturb_past(
         for _ in range(horizon_length):
             inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
             lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
-            curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
+            curr_all_logits, curr_unpert_past, curr_all_hidden = (
+                lm_output["logits"],
+                lm_output["past_key_values"],
+                lm_output["hidden_states"],
+            )
+            curr_logits = curr_all_logits[:, -1, :]
+            curr_probs = nn.functional.softmax(curr_logits, dim=-1)
+            curr_probs = torch.unsqueeze(curr_probs, dim=1)
             curr_hidden = curr_all_hidden[-1]
             new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)