generate_text_pplm now works with batch_size > 1

This commit is contained in:
Piero Molino 2019-12-01 15:48:33 -08:00 committed by Julien Chaumond
parent 893d0d64fe
commit a59fdd1627

View File

@ -231,7 +231,8 @@ def perturb_past(
prediction = classifier(new_accumulated_hidden /
(curr_length + 1 + horizon_length))
label = torch.tensor([class_label], device=device,
label = torch.tensor(prediction.shape[0] * [class_label],
device=device,
dtype=torch.long)
discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@ -508,11 +509,12 @@ def generate_text_pplm(
gm_scale=0.9,
kl_scale=0.01,
):
output_so_far = (
torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
if context
else None
)
output_so_far = None
if context:
context_t = torch.tensor(context, device=device, dtype=torch.long)
while len(context_t.shape) < 2:
context_t = context_t.unsqueeze(0)
output_so_far = context_t
# collect one hot vectors for bags of words
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,