First cleanup step: rename functions and pass parameters explicitly all the way through instead of reading them from the args namespace. Output is identical to before.

piero 2019-11-27 16:32:45 -08:00 committed by Julien Chaumond
parent 821de121e8
commit 4f2164e40e

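The pattern repeated throughout the diff below: values that were read from a shared argparse namespace become explicit keyword parameters with defaults, and call sites forward them by name. A minimal sketch of the idea (the helper names and values are illustrative, not from this commit):

from types import SimpleNamespace

# Before: the helper reaches into a shared args namespace.
def helper_old(x, args):
    return x * args.stepsize

# After: the value is an explicit keyword parameter with a default, so the
# function is callable and testable without an args object. **kwargs absorbs
# extra namespace entries, as full_text_generation's signature does below.
def helper_new(x, stepsize=0.02, **kwargs):
    return x * stepsize

args = SimpleNamespace(stepsize=0.04, unrelated_flag=True)
assert helper_old(1.0, args) == helper_new(1.0, **vars(args)) == 0.04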

@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                        logits)
 
 
-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+    # def perturb_past(past, model, prev, classifier, good_index=None,
+    #                  stepsize=0.01, vocab_size=50257,
+    #                  original_probs=None, accumulated_hidden=None, true_past=None,
+    #                  grad_norms=None):
+    # one_hot_bows_vectors = []
+    # for good_list in good_index:
+    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
+    #     good_list = torch.tensor(good_list).cuda()
+    #     num_good = good_list.shape[0]
+    #     one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+    #     one_hot_good.scatter_(1, good_list, 1)
+    #     one_hot_bows_vectors.append(one_hot_good)
 
     # Generate inital perturbed past
     past_perturb_orig = [
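The commented-out block above (one one-hot row per single-token BoW word, built with scatter_) is what now lives in build_bows_one_hot_vectors. A standalone CPU sketch of the construction (token ids are illustrative):

import torch

# One row per bag-of-words entry, with a 1 at that token's vocabulary index.
vocab_size = 50257
good_list = torch.tensor([[318], [922], [1049]])  # assumed single-token ids, shape (3, 1)
one_hot_good = torch.zeros(good_list.shape[0], vocab_size)
one_hot_good.scatter_(1, good_list, 1)  # row i gets a 1 at column good_list[i]
assert one_hot_good.sum() == 3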
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     if accumulated_hidden is None:
         accumulated_hidden = 0
 
-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                      1:]
     else:
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     window_mask = torch.ones_like(past[0]).cuda()
 
     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())
 
-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)
@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                                                  future_hidden, dim=1)
 
             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
+                    current_length + 1 + horizon_length))
 
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
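For reference, this hunk sits inside the KL term that keeps the perturbed next-token distribution close to the unperturbed one; SmallConst smoothing guards against log(0). The hunk cuts off mid-statement, so the sketch below completes the computation in its standard form; the SmallConst value, shapes, and the final sum are assumptions, not shown by this hunk:

import torch
import torch.nn.functional as F

SmallConst = 1e-15  # assumed value; defined elsewhere in the file
kl_scale = 0.01

unpert_logits = torch.randn(1, 8, 50257)        # (batch, seq, vocab), illustrative
probabs = F.softmax(torch.randn(1, 50257), -1)  # perturbed next-token distribution

p = F.softmax(unpert_logits[:, -1, :], dim=-1)
p = p + SmallConst * (p <= SmallConst).float()           # smooth the reference
correction = SmallConst * (probabs <= SmallConst).float()
corrected_probabs = probabs + correction                  # smooth the perturbed dist
kl_loss = kl_scale * (corrected_probabs * (corrected_probabs / p).log()).sum()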
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
 
         loss.backward()
-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]
         past_perturb_orig = list(map(add, grad, past_perturb_orig))
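The two hunks above are the core update: after loss.backward(), each past tensor steps against its gradient, scaled by its gradient norm raised to gamma and restricted to window_mask. A simplified, self-contained sketch of that normalized step (the loss and shapes are placeholders, and grad_norms is initialized fresh here rather than kept as the running max the first hunk maintains):

import torch
from operator import add

stepsize, gamma = 0.02, 1.5
past_perturb = [torch.randn(2, 4, requires_grad=True) for _ in range(3)]
past_perturb_orig = [p_.detach().numpy() for p_ in past_perturb]
window_mask = torch.ones(2, 4)  # stands in for the decay/window mask

loss = sum((p_ ** 2).sum() for p_ in past_perturb)  # placeholder loss
loss.backward()

# Per-tensor gradient norms; the small constant guards against division by zero.
grad_norms = [torch.norm(p_.grad * window_mask) + 1e-10 for p_ in past_perturb]
grad = [
    -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
    for index, p_ in enumerate(past_perturb)
]
past_perturb_orig = list(map(add, grad, past_perturb_orig))  # apply the step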
@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors
 
 
-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )
@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
     #     actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
 
-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
-        for good_list in good_index:
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]
 
-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
+        loss_type = PPLM_BOW_DISCRIM
 
-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")
 
     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")
 
     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
 
-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
-                                        device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
+        device=device,
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()
 
     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []
 
-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
-                                                                   context=context,
-                                                                   device=device,
-                                                                   perturb=True,
-                                                                   good_index=good_index,
-                                                                   classifier=classifier)
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
+            context=context,
+            device=device,
+            sample=sample,
+            perturb=True,
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
 
 
-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None
 
+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):
 
         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
                 # Piero modified model call
                 _, past, _ = model(output[:, :-1])
 
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]
 
         else:
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
            # true_hidden = model.hidden_states
 
            # Piero modified model call
-           original_probs, true_past, unpert_all_hidden = model(output)
+           unpert_logits, unpert_past, unpert_all_hidden = model(output)
            true_hidden = unpert_all_hidden[-1]
 
        # Modify the past if necessary
-       if i >= args.grad_length:
-           current_stepsize = args.stepsize * 0
+       if i >= grad_length:
+           current_stepsize = stepsize * 0
        else:
-           current_stepsize = args.stepsize
+           current_stepsize = stepsize
 
-       if not perturb or args.num_iterations == 0:
+       if not perturb or num_iterations == 0:
            perturbed_past = past
        else:
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             accumulated_hidden = true_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
 
-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
-                                                                        model,
-                                                                        prev,
-                                                                        args,
-                                                                        good_index=good_index,
-                                                                        stepsize=current_stepsize,
-                                                                        original_probs=original_probs,
-                                                                        true_past=true_past,
-                                                                        accumulated_hidden=accumulated_hidden,
-                                                                        classifier=classifier,
-                                                                        grad_norms=grad_norms)
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
+                model,
+                prev,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
+                accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
+                classifier=classifier,
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
             loss_in_time.append(loss_per_iter)
 
         # Piero modified model call
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst
 
         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:
 
             # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
 
-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                    original_probs ** (1 - gm_scale)))  # + SmallConst
-            log_probs = top_k_filter(log_probs, k=args.top_k,
+                    unpert_logits ** (1 - gm_scale)))  # + SmallConst
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst
 
             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)
 
         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)
 
         if sample:
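For context, the lines above implement the geometric-mean fusion described in the PPLM paper: sample from pert_probs ** gm_scale * unpert_probs ** (1 - gm_scale), renormalized. Note that after the softmax, the renamed unpert_logits variable actually holds probabilities. A minimal standalone sketch (shapes illustrative):

import torch
import torch.nn.functional as F

gm_scale = 0.9
pert_probs = F.softmax(torch.randn(1, 50257), dim=-1)    # perturbed distribution
unpert_probs = F.softmax(torch.randn(1, 50257), dim=-1)  # unperturbed distribution

# Blend the two distributions, then renormalize so the result sums to 1;
# top-k filtering would follow here, as in the hunk above.
fused = pert_probs ** gm_scale * unpert_probs ** (1 - gm_scale)
fused = fused / fused.sum()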
@@ -673,16 +767,16 @@ def run_model():
     collect_gen = dict()
     current_index = 0
-    for out in seq:
-        text = TOKENIZER.decode(out)
+    for tokenized_cond_text in seq:
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)
 
-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )
 
         text_whole = TOKENIZER.decode(out1.tolist()[0])
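This call site is where the new signatures pay off: the parsed argparse namespace is unpacked straight into keyword arguments, and the **kwargs parameter on full_text_generation absorbs any CLI flags that are not explicit parameters. A sketch of the mechanism (the stub and flag names are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--some_cli_only_flag", action="store_true")  # hypothetical
args = parser.parse_args([])

# Stub with the same shape as full_text_generation's signature.
def generation_stub(model=None, context=None, device="cuda", stepsize=0.02,
                    **kwargs):
    return stepsize  # **kwargs silently absorbs some_cli_only_flag

assert generation_stub(**vars(args)) == 0.02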
@@ -712,20 +806,20 @@ def run_model():
                 import colorama
 
                 text_whole = ''
-                for out in output_tokens:
-                    if out in keyword_tokens:
+                for tokenized_cond_text in output_tokens:
+                    if tokenized_cond_text in keyword_tokens:
                         text_whole += '%s%s%s' % (
-                            colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                            colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                             colorama.Style.RESET_ALL)
                     else:
-                        text_whole += TOKENIZER.decode([out])
+                        text_whole += TOKENIZER.decode([tokenized_cond_text])
             else:
                 text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
 
             print(text_whole)
             print("=" * 80)
 
-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
 
         current_index = current_index + 1