More cleanup for run_model. Output is identical to before.

piero 2019-11-27 17:27:39 -08:00 committed by Julien Chaumond
parent 7ffe47c888
commit 6c9c131780


@@ -39,7 +39,6 @@ from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel
PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
@@ -129,8 +128,7 @@ def perturb_past(
decay=False,
gamma=1.5,
):
#def perturb_past(past, model, prev, classifier, good_index=None,
# def perturb_past(past, model, prev, classifier, good_index=None,
# stepsize=0.01, vocab_size=50257,
# original_probs=None, accumulated_hidden=None, true_past=None,
# grad_norms=None):
@@ -237,7 +235,7 @@ def perturb_past(
future_hidden, dim=1)
predicted_sentiment = classifier(new_accumulated_hidden / (
current_length + 1 + horizon_length))
current_length + 1 + horizon_length))
label = torch.tensor([label_class], device='cuda',
dtype=torch.long)
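For context on the hunk above: the discriminator sees the mean of the accumulated hidden states over the generated prefix plus a short look-ahead horizon, and a cross-entropy loss against the target class produces the gradient used for the perturbation. A minimal standalone sketch of that step, with assumed sizes and an assumed linear stand-in for the trained classifier head:

import torch
import torch.nn.functional as F

hidden_size, num_classes = 1024, 2                      # assumed sizes
classifier = torch.nn.Linear(hidden_size, num_classes)  # stand-in for the trained head

accumulated_hidden = torch.randn(1, hidden_size)        # sum of hidden states so far
current_length, horizon_length = 7, 1

avg_hidden = accumulated_hidden / (current_length + 1 + horizon_length)
logits = classifier(avg_hidden)
label = torch.tensor([1], dtype=torch.long)             # target class id
discrim_loss = F.cross_entropy(logits, label)           # drives the gradient on the perturbed past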
@@ -349,6 +347,13 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
bow_indices.append(
[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
words])
#bow_words = set()
#for bow_list in bow_indices:
# bow_list = list(filter(lambda x: len(x) <= 1, bow_list))
# bow_words.update(
# (TOKENIZER.decode(word).strip(), word) for word in bow_list)
return bow_indices
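The indices returned here feed build_bows_one_hot_vectors below; the general idea is to turn each single-token bag-of-words entry into a one-hot row over the vocabulary so the BoW loss can be computed as the probability mass on those tokens. A hedged sketch of that idea, with illustrative helper names rather than the file's exact implementation:

import torch

def bow_one_hot(bow_indices, vocab_size=50257, device="cpu"):
    one_hots = []
    for bow in bow_indices:
        # keep only single-token words, as the filters in this file do
        single = [ids[0] for ids in bow if len(ids) == 1]
        if not single:
            continue
        oh = torch.zeros(len(single), vocab_size, device=device)
        oh.scatter_(1, torch.tensor(single, device=device).unsqueeze(1), 1.0)
        one_hots.append(oh)
    return one_hots

def bow_loss(probs, one_hot_bow):
    # negative log of the total probability assigned to the bag's tokens
    return -torch.log(torch.mm(probs, one_hot_bow.t()).sum())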
@@ -368,28 +373,28 @@ def build_bows_one_hot_vectors(bow_indices):
def full_text_generation(
model,
context=None,
num_samples=1,
device="cuda",
sample=True,
discrim=None,
label_class=None,
bag_of_words=None,
length=100,
grad_length=10000,
stepsize=0.02,
num_iterations=3,
temperature=1.0,
gm_scale=0.9,
kl_scale=0.01,
top_k=10,
window_length=0,
horizon_length=1,
decay=False,
gamma=1.5,
**kwargs
):
model,
context=None,
num_samples=1,
device="cuda",
sample=True,
discrim=None,
label_class=None,
bag_of_words=None,
length=100,
grad_length=10000,
stepsize=0.02,
num_iterations=3,
temperature=1.0,
gm_scale=0.9,
kl_scale=0.01,
top_k=10,
window_length=0,
horizon_length=1,
decay=False,
gamma=1.5,
**kwargs
):
classifier, class_id = get_classifier(
discrim,
label_class,
@@ -465,15 +470,9 @@ def full_text_generation(
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
bow_indices = []
actual_words = None
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
for good_list in bow_indices:
good_list = list(filter(lambda x: len(x) <= 1, good_list))
actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
good_list]
if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
loss_type = PPLM_BOW_DISCRIM
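The branch above handles the combined case; for reference, the full mapping from the two control signals to the loss-type constants defined at the top of the file looks like this (the BoW-only and discriminator-only branches are assumed from context, not shown in this hunk):

PPLM_BOW, PPLM_DISCRIM, PPLM_BOW_DISCRIM = 1, 2, 3

def choose_loss_type(bag_of_words, classifier):
    if bag_of_words and classifier:
        # both controls active; noted above as not optimized
        return PPLM_BOW_DISCRIM
    if bag_of_words:
        return PPLM_BOW
    if classifier:
        return PPLM_DISCRIM
    raise ValueError("Specify either a bag of words or a discriminator")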
@@ -533,8 +532,7 @@ def full_text_generation(
torch.cuda.empty_cache()
return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
return original, perturbed_list, discrim_loss_list, loss_in_time_list
def generate_text_pplm(
@@ -611,25 +609,25 @@ def generate_text_pplm(
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
past,
model,
prev,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
past,
model,
prev,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_per_iter)
# Piero modified model call
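perturb_past, called above, is the core PPLM step: a small additive perturbation of the cached key/value "past" is refined with a few gradient steps so the resulting next-token distribution lowers the attribute loss. A toy sketch of that pattern only; the real code normalizes per layer with running grad_norms and supports windowing and decay:

import torch

def perturb(past, loss_fn, stepsize=0.02, num_iterations=3):
    # one perturbation tensor per element of the cached past
    delta = [torch.zeros_like(p, requires_grad=True) for p in past]
    for _ in range(num_iterations):
        loss = loss_fn([p + d for p, d in zip(past, delta)])
        loss.backward()
        for d in delta:
            with torch.no_grad():
                # normalized gradient step on the perturbation
                d -= stepsize * d.grad / (d.grad.norm() + 1e-10)
            d.grad.zero_()
    return [p + d.detach() for p, d in zip(past, delta)]

# toy usage: "past" is a single tensor, the loss pulls it toward zero
past = [torch.randn(2, 3)]
new_past = perturb(past, lambda ps: (ps[0] ** 2).sum())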
@@ -666,7 +664,7 @@ def generate_text_pplm(
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
log_probs = ((log_probs ** gm_scale) * (
unpert_logits ** (1 - gm_scale))) # + SmallConst
unpert_logits ** (1 - gm_scale))) # + SmallConst
log_probs = top_k_filter(log_probs, k=top_k,
probs=True) # + SmallConst
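The two lines above fuse the perturbed and unperturbed next-token distributions with a geometric mean controlled by gm_scale, then keep only the top-k tokens before sampling. A standalone sketch of that post-norm fusion; the names pert_probs/unpert_probs are illustrative and the snippet is not this file's top_k_filter:

import torch

def fuse_and_filter(pert_probs, unpert_probs, gm_scale=0.9, top_k=10):
    # geometric interpolation between the two distributions
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
    # zero out everything below the k-th largest probability
    values, _ = torch.topk(fused, k=top_k)
    fused = torch.where(fused < values[..., -1:], torch.zeros_like(fused), fused)
    # renormalize so the result is a proper distribution again
    return fused / fused.sum(dim=-1, keepdim=True)

pert = torch.softmax(torch.randn(1, 50257), dim=-1)
unpert = torch.softmax(torch.randn(1, 50257), dim=-1)
next_token = torch.multinomial(fuse_and_filter(pert, unpert), num_samples=1)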
@@ -696,53 +694,88 @@ def generate_text_pplm(
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
help='pretrained model name or path to local checkpoint')
parser.add_argument('--bag-of-words', '-B', type=str, default=None,
help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
parser.add_argument('--discrim', '-D', type=str, default=None,
choices=(
'clickbait', 'sentiment', 'toxicity', 'generic'),
help='Discriminator to use for loss-type 2')
parser.add_argument('--discrim_weights', type=str, default=None,
help='Weights for the generic discriminator')
parser.add_argument('--discrim_meta', type=str, default=None,
help='Meta information for the generic discriminator')
parser.add_argument('--label_class', type=int, default=-1,
help='Class label used for the discriminator')
parser.add_argument('--stepsize', type=float, default=0.02)
parser.add_argument(
"--model_path",
"-M",
type=str,
default="gpt2-medium",
help="pretrained model name or path to local checkpoint",
)
parser.add_argument(
"--bag_of_words",
"-B",
type=str,
default=None,
help="Bags of words used for PPLM-BoW. "
"Either a BOW id (see list in code) or a filepath. "
"Multiple BoWs separated by ;",
)
parser.add_argument(
"--discrim",
"-D",
type=str,
default=None,
choices=("clickbait", "sentiment", "toxicity"),
help="Discriminator to use for loss-type 2",
)
parser.add_argument(
"--label_class",
type=int,
default=-1,
help="Class label used for the discriminator",
)
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--length", type=int, default=100)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=10)
parser.add_argument("--gm_scale", type=float, default=0.9)
parser.add_argument("--kl_scale", type=float, default=0.01)
parser.add_argument('--nocuda', action='store_true', help='no cuda')
parser.add_argument('--uncond', action='store_true',
help='Generate from end-of-text as prefix')
parser.add_argument("--cond_text", type=str, default='The lake',
help='Prefix texts to condition on')
parser.add_argument('--num_iterations', type=int, default=3)
parser.add_argument('--grad_length', type=int, default=10000)
parser.add_argument('--num_samples', type=int, default=1,
help='Number of samples to generate from the modified latents')
parser.add_argument('--horizon_length', type=int, default=1,
help='Length of future to optimize over')
# parser.add_argument('--force-token', action='store_true', help='no cuda')
parser.add_argument('--window_length', type=int, default=0,
help='Length of past which is being optimizer; 0 corresponds to infinite window length')
parser.add_argument('--decay', action='store_true',
help='whether to decay or not')
parser.add_argument('--gamma', type=float, default=1.5)
parser.add_argument('--colorama', action='store_true', help='no cuda')
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument(
"--uncond", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument(
"--cond_text", type=str, default="The lake",
help="Prefix texts to condition on"
)
parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--grad_length", type=int, default=10000)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples to generate from the modified latents",
)
parser.add_argument(
"--horizon_length",
type=int,
default=1,
help="Length of future to optimize over",
)
parser.add_argument(
"--window_length",
type=int,
default=0,
help="Length of past which is being optimized; "
"0 corresponds to infinite window length",
)
parser.add_argument("--decay", action="store_true",
help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--colorama", action="store_true", help="colors keywords")
args = parser.parse_args()
# set Random seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = 'cpu' if args.nocuda else 'cuda'
# set the device
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
args.model_path,
output_hidden_states=True
@@ -753,76 +786,77 @@ def run_model():
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
pass
# figure out conditioning text
if args.uncond:
seq = [[50256, 50256]]
tokenized_cond_text = TOKENIZER.encode(
[TOKENIZER.bos_token]
)
else:
raw_text = args.cond_text
while not raw_text:
print('Did you forget to add `--cond-text`? ')
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
seq = [[50256] + TOKENIZER.encode(raw_text)]
tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
collect_gen = dict()
current_index = 0
for tokenized_cond_text in seq:
print("= Prefix of sentence =")
print(TOKENIZER.decode(tokenized_cond_text))
print()
text = TOKENIZER.decode(tokenized_cond_text)
print("=" * 40 + " Prefix of sentence " + "=" * 40)
print(text)
print("=" * 80)
# generate unperturbed and perturbed texts
out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
model=model, context=tokenized_cond_text, device=device, **vars(args)
)
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
model=model, context=tokenized_cond_text, device=device, **vars(args)
)
text_whole = TOKENIZER.decode(out1.tolist()[0])
# untokenize unperturbed text
unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
print("=" * 80)
print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
print(text_whole)
print("=" * 80)
print("=" * 80)
print("= Unperturbed generated text =")
print(unpert_gen_text)
print()
out_perturb_copy = out_perturb
generated_texts = []
for out_perturb in out_perturb_copy:
# try:
# print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
# text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
# print(text_whole)
# print("=" * 80)
# except:
# pass
# collect_gen[current_index] = [out, out_perturb, out1]
## Save the prefix, perturbed seq, original seq for each index
print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
keyword_tokens = [aa[-1][0] for aa in
actual_words] if actual_words else []
output_tokens = out_perturb.tolist()[0]
bow_words = set()
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
for bow_list in bow_indices:
filtered = list(filter(lambda x: len(x) <= 1, bow_list))
bow_words.update(w[0] for w in filtered)
# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# untokenize unperturbed text
if args.colorama:
import colorama
text_whole = ''
for tokenized_cond_text in output_tokens:
if tokenized_cond_text in keyword_tokens:
text_whole += '%s%s%s' % (
colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
colorama.Style.RESET_ALL)
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_words:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
TOKENIZER.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
text_whole += TOKENIZER.decode([tokenized_cond_text])
pert_gen_text += TOKENIZER.decode([word_id])
else:
text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
print(text_whole)
print("=" * 80)
collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
current_index = current_index + 1
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
print()
except:
pass
# keep the prefix, perturbed seq, original seq for each index
generated_texts.append(
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)
return
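Finally, the --colorama path added above wraps bag-of-words tokens in ANSI color codes while decoding. A self-contained sketch of that highlighting idea, using a fake decoder so it runs without the GPT-2 tokenizer:

import colorama

def highlight(token_ids, bow_token_ids, decode):
    out = ""
    for tok in token_ids:
        piece = decode([tok])
        if tok in bow_token_ids:
            # keyword token: wrap in red, then reset styling
            out += colorama.Fore.RED + piece + colorama.Style.RESET_ALL
        else:
            out += piece
    return out

colorama.init()  # enables ANSI handling on Windows
print(highlight([1, 2, 3], {2}, lambda ids: " tok%d" % ids[0]))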