diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 57bed3890f6..5e094278792 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -231,7 +231,8 @@ def perturb_past( prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length)) - label = torch.tensor([class_label], device=device, + label = torch.tensor(prediction.shape[0] * [class_label], + device=device, dtype=torch.long) discrim_loss = ce_loss(prediction, label) print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) @@ -508,11 +509,12 @@ def generate_text_pplm( gm_scale=0.9, kl_scale=0.01, ): - output_so_far = ( - torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) - if context - else None - ) + output_so_far = None + if context: + context_t = torch.tensor(context, device=device, dtype=torch.long) + while len(context_t.shape) < 2: + context_t = context_t.unsqueeze(0) + output_so_far = context_t # collect one hot vectors for bags of words one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,