First cleanup step: rename functions and pass parameters all the way through explicitly instead of reading them from the `args` namespace. Output is identical to before.

piero 2019-11-27 16:32:45 -08:00 committed by Julien Chaumond
parent 821de121e8
commit 4f2164e40e
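
The pattern this commit applies throughout the file: helpers stop reading fields off the argparse namespace and instead declare explicit keyword parameters with defaults, while the entry point unpacks the namespace once with `**vars(args)` (visible in the `run_model` hunk at the bottom). A minimal, runnable sketch of that pattern with hypothetical names:

import argparse


def generate(length=100, stepsize=0.02, temperature=1.0, **kwargs):
    # explicit parameters replace scattered args.length / args.stepsize reads;
    # **kwargs absorbs namespace entries this helper does not use
    return "length=%s stepsize=%s temperature=%s" % (length, stepsize, temperature)


parser = argparse.ArgumentParser()
parser.add_argument("--length", type=int, default=100)
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--temperature", type=float, default=1.0)
args = parser.parse_args([])  # empty argv so the sketch runs on defaults

# vars(args) turns the Namespace into a plain dict of keyword arguments
print(generate(**vars(args)))

The trailing `**kwargs` is what lets `**vars(args)` keep working even when the namespace carries flags a given helper never uses, as in the new `full_text_generation` signature below.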


@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                            logits)


-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+    #def perturb_past(past, model, prev, classifier, good_index=None,
+    #                 stepsize=0.01, vocab_size=50257,
+    #                 original_probs=None, accumulated_hidden=None, true_past=None,
+    #                 grad_norms=None):
+    # one_hot_bows_vectors = []
+    # for good_list in good_index:
+    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
+    #     good_list = torch.tensor(good_list).cuda()
+    #     num_good = good_list.shape[0]
+    #     one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+    #     one_hot_good.scatter_(1, good_list, 1)
+    #     one_hot_bows_vectors.append(one_hot_good)

     # Generate inital perturbed past
     past_perturb_orig = [
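
The bag-of-words one-hot construction commented out above now lives in `build_bows_one_hot_vectors` (see its hunk further down). A standalone sketch of what it computes, on CPU for illustration where the script calls `.cuda()`, with hypothetical token ids:

import torch

vocab_size = 50257  # GPT-2 vocabulary size, as in the script
# hypothetical single-token bag-of-words entries (token ids)
good_list = torch.tensor([[2061], [318], [257]])

num_good = good_list.shape[0]
one_hot_good = torch.zeros(num_good, vocab_size)
one_hot_good.scatter_(1, good_list, 1)  # one row per word, a 1 at its token id
print(one_hot_good.sum(dim=1))  # tensor([1., 1., 1.])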
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

     if accumulated_hidden is None:
         accumulated_hidden = 0
-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                      1:]
     else:
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     window_mask = torch.ones_like(past[0]).cuda()

     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())

-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)
@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                     future_hidden, dim=1)

             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
+                    current_length + 1 + horizon_length))

-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
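
For context, the full KL term this hunk is part of, as a runnable CPU sketch with random logits standing in for model output; both distributions are floored by `SmallConst` so the log never sees an exact zero (the constant's value here is a stand-in for the one defined near the top of the script):

import torch
import torch.nn.functional as F

SmallConst = 1e-15  # stand-in for the script-level constant
kl_scale = 0.01

probabs = F.softmax(torch.randn(1, 50257), dim=-1)  # perturbed next-token dist
p = F.softmax(torch.randn(1, 50257), dim=-1)        # unperturbed dist (unpert_logits)
p = p + SmallConst * (p <= SmallConst).float()
correction = SmallConst * (probabs <= SmallConst).float()
corrected_probabs = probabs + correction.detach()

kl_loss = kl_scale * (corrected_probabs * (corrected_probabs / p).log()).sum()
print(kl_loss.item())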
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
         loss.backward()

-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]

         past_perturb_orig = list(map(add, grad, past_perturb_orig))
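
The update this hunk completes, as a numpy sketch: each past tensor's masked gradient is scaled by `-stepsize` and divided by its running norm raised to `gamma`, then accumulated into the perturbation (the script keeps the accumulator in numpy between iterations):

import numpy as np

stepsize, gamma = 0.01, 1.5
grad = np.random.randn(2, 4)          # stands in for p_.grad * window_mask
grad_norm = np.linalg.norm(grad)      # stands in for the running max norm
past_perturb_orig = np.zeros((2, 4))  # accumulated perturbation of one past tensor

# step against the gradient, scaled down by the norm raised to gamma
update = -stepsize * grad / grad_norm ** gamma
past_perturb_orig = past_perturb_orig + update
print(np.abs(past_perturb_orig).max())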
@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors


-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )
@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
     #     actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

-        for good_list in good_index:
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]

-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
+        loss_type = PPLM_BOW_DISCRIM

-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")

     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")

     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
-                                        device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
+        device=device,
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()

     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []
-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
-                                                                   context=context,
-                                                                   device=device,
-                                                                   sample=sample,
-                                                                   perturb=True,
-                                                                   good_index=good_index,
-                                                                   classifier=classifier)
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
+            context=context,
+            device=device,
+            sample=sample,
+            perturb=True,
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
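
The `loss_type` values assigned above come from constants defined earlier in the script; `perturb_past` compares against the raw integers (1, 2, 3). A small sketch of the branching, with the constant values assumed to match the script:

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3


def active_losses(loss_type):
    # mirrors perturb_past: BoW loss runs for 1 and 3,
    # discriminator loss runs for 2 and 3
    return {
        "bow": loss_type in (PPLM_BOW, PPLM_BOW_DISCRIM),
        "discrim": loss_type in (PPLM_DISCRIM, PPLM_BOW_DISCRIM),
    }


print(active_losses(PPLM_BOW_DISCRIM))  # {'bow': True, 'discrim': True}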
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words


-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None

+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):

         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
                 # Piero modified model call
                 _, past, _ = model(output[:, :-1])

-                original_probs, true_past, unpert_all_hidden = model(output)
+                unpert_logits, unpert_past, unpert_all_hidden = model(output)
                 true_hidden = unpert_all_hidden[-1]

         else:
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
            # true_hidden = model.hidden_states

            # Piero modified model call
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]

         # Modify the past if necessary
-        if i >= args.grad_length:
-            current_stepsize = args.stepsize * 0
+        if i >= grad_length:
+            current_stepsize = stepsize * 0
         else:
-            current_stepsize = args.stepsize
+            current_stepsize = stepsize

-        if not perturb or args.num_iterations == 0:
+        if not perturb or num_iterations == 0:
             perturbed_past = past
         else:
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
                 accumulated_hidden = true_hidden[:, :-1, :]
                 accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
-                                                                        model,
-                                                                        prev,
-                                                                        args,
-                                                                        good_index=good_index,
-                                                                        stepsize=current_stepsize,
-                                                                        original_probs=original_probs,
-                                                                        true_past=true_past,
-                                                                        accumulated_hidden=accumulated_hidden,
-                                                                        classifier=classifier,
-                                                                        grad_norms=grad_norms)
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
+                model,
+                prev,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
+                accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
+                classifier=classifier,
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
             loss_in_time.append(loss_per_iter)

         # Piero modified model call
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,

         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst

         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:

             # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                    original_probs ** (1 - gm_scale)))  # + SmallConst
+                    unpert_logits ** (1 - gm_scale)))  # + SmallConst

-            log_probs = top_k_filter(log_probs, k=args.top_k,
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst

             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)

         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)

         if sample:
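
The fused sampling distribution this hunk renames, as a runnable sketch: the perturbed and unperturbed next-token distributions are combined by a weighted geometric mean, then renormalized when probability mass is lost. (Note the variable reuse: at this point `unpert_logits` actually holds softmaxed probabilities.)

import torch
import torch.nn.functional as F

gm_scale = 0.9  # default from the signatures above
pert_probs = F.softmax(torch.randn(1, 50257), dim=-1)
unpert_probs = F.softmax(torch.randn(1, 50257), dim=-1)

fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
if torch.sum(fused) <= 1:  # the geometric mean loses probability mass
    fused = fused / torch.sum(fused)
print(torch.sum(fused).item())  # ~1.0 after renormalization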
@@ -673,16 +767,16 @@ def run_model():
     collect_gen = dict()
     current_index = 0
-    for out in seq:
+    for tokenized_cond_text in seq:

-        text = TOKENIZER.decode(out)
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)

-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )

         text_whole = TOKENIZER.decode(out1.tolist()[0])
@@ -712,20 +806,20 @@ def run_model():
             import colorama

             text_whole = ''
-            for out in output_tokens:
-                if out in keyword_tokens:
+            for tokenized_cond_text in output_tokens:
+                if tokenized_cond_text in keyword_tokens:
                     text_whole += '%s%s%s' % (
-                        colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                        colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                         colorama.Style.RESET_ALL)
                 else:
-                    text_whole += TOKENIZER.decode([out])
+                    text_whole += TOKENIZER.decode([tokenized_cond_text])

         else:
             text_whole = TOKENIZER.decode(out_perturb.tolist()[0])

         print(text_whole)
         print("=" * 80)

-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]

         current_index = current_index + 1