Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
First cleanup step: rename functions and pass parameters all the way through explicitly instead of reading them from the args namespace. Output is identical to before.
This commit is contained in:
parent 821de121e8
commit 4f2164e40e
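The pattern applied throughout this diff is to stop reading configuration off the argparse namespace inside the library functions and instead declare explicit, defaulted keyword parameters, with the top-level call site forwarding the parsed flags via **vars(args). A minimal sketch of that pattern, assuming illustrative names (generate_old, generate, and the flags below are hypothetical stand-ins, not the functions in this file):

import argparse

def generate_old(model, args):
    # before: a hidden dependency on the whole argparse namespace
    return (args.stepsize, args.num_iterations)

def generate(model, stepsize=0.02, num_iterations=3, **kwargs):
    # after: explicit, defaulted parameters; **kwargs silently absorbs
    # any parsed flags this particular function does not use
    return (stepsize, num_iterations)

parser = argparse.ArgumentParser()
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--num_samples", type=int, default=1)
args = parser.parse_args([])

model = None  # stand-in; the real script passes a GPT-2 model
assert generate_old(model, args) == generate(model, **vars(args))

Run with defaults, both styles return the same values, which is the "identical output" invariant the commit message claims.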
@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                          logits)


-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+
+    #def perturb_past(past, model, prev, classifier, good_index=None,
+    #                 stepsize=0.01, vocab_size=50257,
+    #                 original_probs=None, accumulated_hidden=None, true_past=None,
+    #                 grad_norms=None):
+
+    # one_hot_bows_vectors = []
+    # for good_list in good_index:
+    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
+    #     good_list = torch.tensor(good_list).cuda()
+    #     num_good = good_list.shape[0]
+    #     one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+    #     one_hot_good.scatter_(1, good_list, 1)
+    #     one_hot_bows_vectors.append(one_hot_good)

     # Generate inital perturbed past
     past_perturb_orig = [
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     if accumulated_hidden is None:
         accumulated_hidden = 0

-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                      1:]
     else:
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     window_mask = torch.ones_like(past[0]).cuda()

     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())

-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)

@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 future_hidden, dim=1)

             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
+                    current_length + 1 + horizon_length))

-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

         loss.backward()
-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]
         past_perturb_orig = list(map(add, grad, past_perturb_orig))

@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors


-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )

@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     # good_list = list(filter(lambda x: len(x) <= 1, good_list))
     # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

-        for good_list in good_index:
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]

-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
+        loss_type = PPLM_BOW_DISCRIM

-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")

     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")

     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
-                                        device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
+        device=device,
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()

     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []

-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
-                                                                   context=context,
-                                                                   device=device,
-                                                                   perturb=True,
-                                                                   good_index=good_index,
-                                                                   classifier=classifier)
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
+            context=context,
+            device=device,
+            sample=sample,
+            perturb=True,
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words


-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None

+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):

         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,

             # Piero modified model call
             _, past, _ = model(output[:, :-1])
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]

         else:
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # true_hidden = model.hidden_states

             # Piero modified model call
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]

         # Modify the past if necessary

-        if i >= args.grad_length:
-            current_stepsize = args.stepsize * 0
+        if i >= grad_length:
+            current_stepsize = stepsize * 0
         else:
-            current_stepsize = args.stepsize
+            current_stepsize = stepsize

-        if not perturb or args.num_iterations == 0:
+        if not perturb or num_iterations == 0:
             perturbed_past = past

         else:
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             accumulated_hidden = true_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
-                                                                        model,
-                                                                        prev,
-                                                                        args,
-                                                                        good_index=good_index,
-                                                                        stepsize=current_stepsize,
-                                                                        original_probs=original_probs,
-                                                                        true_past=true_past,
-                                                                        accumulated_hidden=accumulated_hidden,
-                                                                        classifier=classifier,
-                                                                        grad_norms=grad_norms)
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
+                model,
+                prev,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
+                accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
+                classifier=classifier,
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
             loss_in_time.append(loss_per_iter)

             # Piero modified model call
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst

         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst

@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:

             # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                    original_probs ** (1 - gm_scale)))  # + SmallConst
+                    unpert_logits ** (1 - gm_scale)))  # + SmallConst

-            log_probs = top_k_filter(log_probs, k=args.top_k,
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst

             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)

         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)

         if sample:
@@ -673,16 +767,16 @@ def run_model():

     collect_gen = dict()
     current_index = 0
-    for out in seq:
+    for tokenized_cond_text in seq:

-        text = TOKENIZER.decode(out)
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)

-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )

         text_whole = TOKENIZER.decode(out1.tolist()[0])

@@ -712,20 +806,20 @@ def run_model():
             import colorama

             text_whole = ''
-            for out in output_tokens:
-                if out in keyword_tokens:
+            for tokenized_cond_text in output_tokens:
+                if tokenized_cond_text in keyword_tokens:
                     text_whole += '%s%s%s' % (
-                        colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                        colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                         colorama.Style.RESET_ALL)
                 else:
-                    text_whole += TOKENIZER.decode([out])
+                    text_whole += TOKENIZER.decode([tokenized_cond_text])
         else:
             text_whole = TOKENIZER.decode(out_perturb.tolist()[0])

         print(text_whole)
         print("=" * 80)

-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]

         current_index = current_index + 1

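One step worth calling out from the @@ -566,22 +661,21 @@ hunk: despite their names, log_probs and the post-softmax unpert_logits both hold probabilities at that point, and the fused sampling distribution is a geometric mean of the perturbed and unperturbed next-token distributions. A self-contained sketch of that fusion step, assuming softmax inputs and an always-on renormalization (fuse_distributions is an illustrative name, not a function in this file):

import torch
import torch.nn.functional as F

def fuse_distributions(pert_scores, unpert_scores, gm_scale=0.9):
    # turn both next-token score vectors into probability distributions
    pert_probs = F.softmax(pert_scores, dim=-1)
    unpert_probs = F.softmax(unpert_scores, dim=-1)
    # geometric-mean fusion, mirroring the diff:
    # (log_probs ** gm_scale) * (unpert_logits ** (1 - gm_scale))
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
    # renormalize so the fused scores form a distribution again
    return fused / torch.sum(fused, dim=-1, keepdim=True)

pert = torch.randn(1, 50257)    # stand-in perturbed model scores
unpert = torch.randn(1, 50257)  # stand-in unperturbed model scores
probs = fuse_distributions(pert, unpert)
assert torch.allclose(probs.sum(dim=-1), torch.ones(1))

gm_scale = 1.0 recovers pure perturbed sampling and gm_scale = 0.0 recovers the unperturbed model; the default of 0.9 leans heavily on the perturbed distribution.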