Mirror of https://github.com/huggingface/transformers.git
generate_text_pplm now works with batch_size > 1
commit a59fdd1627
parent 893d0d64fe
@@ -231,7 +231,8 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
 
-            label = torch.tensor([class_label], device=device,
+            label = torch.tensor(prediction.shape[0] * [class_label],
+                                 device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
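The fix in perturb_past: the discriminator's prediction has one row per batch element, but the old code built a one-element label tensor, so ce_loss(prediction, label) only matched shapes when batch_size == 1. Replicating class_label once per row makes the shapes agree for any batch size. A minimal standalone sketch of the shape logic (the sizes and the random prediction are made-up stand-ins for the example's real classifier output):

import torch

ce_loss = torch.nn.CrossEntropyLoss()

batch_size, num_classes, class_label = 4, 5, 2     # hypothetical sizes
prediction = torch.randn(batch_size, num_classes)  # stand-in for classifier(...)

# prediction.shape[0] * [class_label] == [2, 2, 2, 2]: one label per row
label = torch.tensor(prediction.shape[0] * [class_label], dtype=torch.long)

discrim_loss = ce_loss(prediction, label)  # valid for any batch_size
print(label.shape, discrim_loss.item())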
@@ -508,11 +509,12 @@ def generate_text_pplm(
     gm_scale=0.9,
     kl_scale=0.01,
 ):
-    output_so_far = (
-        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
-        if context
-        else None
-    )
+    output_so_far = None
+    if context:
+        context_t = torch.tensor(context, device=device, dtype=torch.long)
+        while len(context_t.shape) < 2:
+            context_t = context_t.unsqueeze(0)
+        output_so_far = context_t
 
     # collect one hot vectors for bags of words
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
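The fix in generate_text_pplm: the old ternary always applied .unsqueeze(0), which is correct for a single 1-D list of token ids but turns an already-batched 2-D context into a 3-D tensor. The loop instead adds leading dimensions only until the tensor is 2-D, so both cases end up shaped (batch_size, seq_len). A small illustration of that normalization (to_2d is a hypothetical helper, and the token ids are arbitrary):

import torch

def to_2d(context):
    # Add leading dims until the tensor is (batch_size, seq_len).
    context_t = torch.tensor(context, dtype=torch.long)
    while len(context_t.shape) < 2:
        context_t = context_t.unsqueeze(0)
    return context_t

print(to_2d([5, 7, 11]).shape)               # torch.Size([1, 3]) -- single prompt
print(to_2d([[5, 7, 11], [2, 3, 4]]).shape)  # torch.Size([2, 3]) -- batch of two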