Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
First cleanup step: rename functions and pass parameters all the way through explicitly instead of reading them from args. Output is identical to before.
This commit is contained in:
parent 821de121e8
commit 4f2164e40e
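The pattern this commit applies throughout the file is mechanical: every read from the parsed argparse namespace (args.stepsize, args.loss_type, ...) becomes an explicit keyword parameter with a default, so the functions can be called without argparse at all. A minimal sketch of the before/after shape, using illustrative names rather than the file's real ones:

    import argparse

    def step(model, stepsize):
        # Stand-in for one unit of work.
        print(f"step({model!r}, stepsize={stepsize})")

    # Before: the function reads everything from a parsed-argument namespace.
    def generate_old(model, args):
        for _ in range(args.num_iterations):
            step(model, args.stepsize)

    # After: each value is an explicit keyword parameter with a default,
    # so the function no longer depends on argparse.
    def generate(model, num_iterations=3, stepsize=0.02, **kwargs):
        for _ in range(num_iterations):
            step(model, stepsize)

    # A call site can still forward a parsed namespace in one line:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--stepsize", type=float, default=0.02)
    args = parser.parse_args([])
    generate("gpt2", **vars(args))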
@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                        logits)


-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+
+    # def perturb_past(past, model, prev, classifier, good_index=None,
+    #                  stepsize=0.01, vocab_size=50257,
+    #                  original_probs=None, accumulated_hidden=None, true_past=None,
+    #                  grad_norms=None):
+
+    # one_hot_bows_vectors = []
+    # for good_list in good_index:
+    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
+    #     good_list = torch.tensor(good_list).cuda()
+    #     num_good = good_list.shape[0]
+    #     one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+    #     one_hot_good.scatter_(1, good_list, 1)
+    #     one_hot_bows_vectors.append(one_hot_good)

     # Generate inital perturbed past
     past_perturb_orig = [
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     if accumulated_hidden is None:
         accumulated_hidden = 0

-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                      1:]
     else:
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         window_mask = torch.ones_like(past[0]).cuda()

     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())

-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)

@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 future_hidden, dim=1)

             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
+                    current_length + 1 + horizon_length))

-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

         loss.backward()
-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,

         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]
         past_perturb_orig = list(map(add, grad, past_perturb_orig))

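The two hunks above touch the heart of the PPLM update: after backpropagating the attribute loss, each past tensor gets a windowed, norm-scaled gradient step, with gamma damping the normalization. A self-contained sketch of that update rule under assumed shapes, with random tensors standing in for the model's cached past and a small constant in place of SmallConst:

    import torch

    stepsize, gamma = 0.02, 1.5
    small_const = 1e-15
    # Hypothetical perturbation tensors and window mask.
    past_perturb = [torch.randn(2, 5, requires_grad=True) for _ in range(2)]
    window_mask = torch.ones(2, 5)

    # Any differentiable stand-in loss over the perturbation.
    loss = sum((p_ * p_).sum() for p_ in past_perturb)
    loss.backward()

    # First iteration: initialize the per-tensor gradient norms; later
    # iterations take a running max, as in the hunk above.
    grad_norms = [torch.norm(p_.grad * window_mask) + small_const
                  for p_ in past_perturb]

    # Normalized, windowed gradient step, mirroring the update in the diff.
    grad = [
        (-stepsize * (p_.grad * window_mask
                      / grad_norms[index] ** gamma)).numpy()
        for index, p_ in enumerate(past_perturb)
    ]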
@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors


-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )

@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     # good_list = list(filter(lambda x: len(x) <= 1, good_list))
     # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

-        for good_list in good_index:
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]

-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
+        loss_type = PPLM_BOW_DISCRIM

-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")

     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")

     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
-                                        device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
+        device=device,
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()

     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []

-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
-                                                                   context=context,
-                                                                   device=device,
-                                                                   sample=sample,
-                                                                   perturb=True,
-                                                                   good_index=good_index,
-                                                                   classifier=classifier)
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
+            context=context,
+            device=device,
+            sample=sample,
+            perturb=True,
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words


-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None

+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):

         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,

             # Piero modified model call
             _, past, _ = model(output[:, :-1])
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]

         else:
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # true_hidden = model.hidden_states

             # Piero modified model call
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]

         # Modify the past if necessary

-        if i >= args.grad_length:
-            current_stepsize = args.stepsize * 0
+        if i >= grad_length:
+            current_stepsize = stepsize * 0
         else:
-            current_stepsize = args.stepsize
+            current_stepsize = stepsize

-        if not perturb or args.num_iterations == 0:
+        if not perturb or num_iterations == 0:
             perturbed_past = past

         else:
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
                 accumulated_hidden = true_hidden[:, :-1, :]
                 accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
-                                                                        model,
-                                                                        prev,
-                                                                        args,
-                                                                        good_index=good_index,
-                                                                        stepsize=current_stepsize,
-                                                                        original_probs=original_probs,
-                                                                        true_past=true_past,
-                                                                        accumulated_hidden=accumulated_hidden,
-                                                                        classifier=classifier,
-                                                                        grad_norms=grad_norms)
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
+                model,
+                prev,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
+                accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
+                classifier=classifier,
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
             loss_in_time.append(loss_per_iter)

             # Piero modified model call
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst

         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst

@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:

             # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                    original_probs ** (1 - gm_scale)))  # + SmallConst
+                    unpert_logits ** (1 - gm_scale)))  # + SmallConst

-            log_probs = top_k_filter(log_probs, k=args.top_k,
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst

             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)

         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)

         if sample:
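The gm_scale branch in the hunk above fuses the perturbed and unperturbed next-token distributions as a weighted geometric mean, roughly p(x) proportional to p_pert(x)^gm_scale * p_unpert(x)^(1 - gm_scale), then renormalizes before sampling (note that the reassigned unpert_logits actually holds probabilities at that point). A self-contained sketch with random logits standing in for the two model outputs:

    import torch
    import torch.nn.functional as F

    gm_scale = 0.9
    vocab_size = 50257
    pert_logits = torch.randn(1, vocab_size)    # stand-in: perturbed output
    unpert_logits = torch.randn(1, vocab_size)  # stand-in: unperturbed output

    pert_probs = F.softmax(pert_logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)

    # Weighted geometric mean of the two distributions.
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

    # Renormalize (the diff guards this with `if torch.sum(log_probs) <= 1`).
    fused = fused / fused.sum()

    prev = torch.multinomial(fused, num_samples=1)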
@@ -673,16 +767,16 @@ def run_model():

     collect_gen = dict()
     current_index = 0
-    for out in seq:
+    for tokenized_cond_text in seq:

-        text = TOKENIZER.decode(out)
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)

-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )

         text_whole = TOKENIZER.decode(out1.tolist()[0])

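The rewritten call site above is what ties the cleanup together: vars(args) turns the parsed namespace into a dict, ** expands it into keyword arguments, and the **kwargs catch-all on full_text_generation absorbs any namespace entries the function does not declare. A minimal sketch of that bridge, with a hypothetical trimmed-down signature:

    from argparse import Namespace

    def full_text_generation(model, context=None, length=100, stepsize=0.02,
                             **kwargs):
        # **kwargs silently absorbs namespace entries this function
        # does not use (e.g. cond_text below).
        print(model, context, length, stepsize, sorted(kwargs))

    # Stand-in for the parsed command-line arguments.
    args = Namespace(length=20, stepsize=0.03, cond_text="The lake",
                     num_samples=1)

    full_text_generation(model="gpt2", context=[464, 13546], **vars(args))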
@@ -712,20 +806,20 @@ def run_model():
             import colorama

             text_whole = ''
-            for out in output_tokens:
-                if out in keyword_tokens:
+            for tokenized_cond_text in output_tokens:
+                if tokenized_cond_text in keyword_tokens:
                     text_whole += '%s%s%s' % (
-                        colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                        colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                         colorama.Style.RESET_ALL)
                 else:
-                    text_whole += TOKENIZER.decode([out])
+                    text_whole += TOKENIZER.decode([tokenized_cond_text])
         else:
             text_whole = TOKENIZER.decode(out_perturb.tolist()[0])

         print(text_whole)
         print("=" * 80)

-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]

         current_index = current_index + 1