More cleanup for run_model. Identical output as before.

piero 2019-11-27 17:27:39 -08:00 committed by Julien Chaumond
parent 7ffe47c888
commit 6c9c131780


@@ -39,7 +39,6 @@ from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel

-
 PPLM_BOW = 1
 PPLM_DISCRIM = 2
 PPLM_BOW_DISCRIM = 3
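Note: these constants select which attribute loss perturb_past optimizes. Only the combined branch is visible in this diff (last hunk of full_text_generation below); the single-attribute branches in this sketch are assumptions, not part of the commit:

    # sketch, not part of the commit: how the loss-type flags appear to be chosen
    PPLM_BOW, PPLM_DISCRIM, PPLM_BOW_DISCRIM = 1, 2, 3

    def pick_loss_type(bag_of_words, classifier):
        if bag_of_words and classifier:
            return PPLM_BOW_DISCRIM  # both attribute models ("not optimized")
        if bag_of_words:
            return PPLM_BOW          # assumption
        if classifier:
            return PPLM_DISCRIM      # assumption
        raise Exception("Specify either a bag of words or a discriminator")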
@@ -129,8 +128,7 @@ def perturb_past(
     decay=False,
     gamma=1.5,
 ):
-
-# def perturb_past(past, model, prev, classifier, good_index=None,
+#def perturb_past(past, model, prev, classifier, good_index=None,
 #                  stepsize=0.01, vocab_size=50257,
 #                  original_probs=None, accumulated_hidden=None, true_past=None,
 #                  grad_norms=None):
@@ -349,6 +347,13 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
             bow_indices.append(
                 [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
                  words])

+    #bow_words = set()
+    #for bow_list in bow_indices:
+    #    bow_list = list(filter(lambda x: len(x) <= 1, bow_list))
+    #    bow_words.update(
+    #        (TOKENIZER.decode(word).strip(), word) for word in bow_list)
+
     return bow_indices
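Note: get_bag_of_words_indices encodes each BoW word with the module-level GPT2Tokenizer. A standalone sketch of the encoding step; "gpt2-medium" matches the script's default model, while the word list and printed ids are illustrative:

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    words = ["lake", "river", "ocean"]
    bow_indices = [tokenizer.encode(w.strip(), add_prefix_space=True)
                   for w in words]
    # each word maps to a list of BPE ids; with add_prefix_space=True most
    # common words encode to a single id, e.g. [[13546], [7850], [9151]]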
@@ -389,7 +394,7 @@ def full_text_generation(
     decay=False,
     gamma=1.5,
     **kwargs
 ):
     classifier, class_id = get_classifier(
         discrim,
         label_class,
@@ -465,15 +470,9 @@ def full_text_generation(
     # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

     bow_indices = []
-    actual_words = None
     if bag_of_words:
         bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
-        for good_list in bow_indices:
-            good_list = list(filter(lambda x: len(x) <= 1, good_list))
-            actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
-                            good_list]
-

     if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
         loss_type = PPLM_BOW_DISCRIM
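Note: the removed loop kept only BoW entries that encode to a single token; the commit relocates that logic into run_model (last hunk below). A sketch of the filter with made-up ids:

    # only entries that encoded to a single BPE token survive the filter;
    # multi-token BoW words are dropped from keyword highlighting
    bow_list = [[13546], [7850, 290], [9151]]
    single_token = list(filter(lambda x: len(x) <= 1, bow_list))
    print(single_token)  # [[13546], [9151]]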
@@ -533,8 +532,7 @@ def full_text_generation(
     torch.cuda.empty_cache()

-    return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
+    return original, perturbed_list, discrim_loss_list, loss_in_time_list


 def generate_text_pplm(
@@ -696,53 +694,88 @@ def generate_text_pplm(
 def run_model():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
-                        help='pretrained model name or path to local checkpoint')
-    parser.add_argument('--bag-of-words', '-B', type=str, default=None,
-                        help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
-    parser.add_argument('--discrim', '-D', type=str, default=None,
-                        choices=(
-                            'clickbait', 'sentiment', 'toxicity', 'generic'),
-                        help='Discriminator to use for loss-type 2')
-    parser.add_argument('--discrim_weights', type=str, default=None,
-                        help='Weights for the generic discriminator')
-    parser.add_argument('--discrim_meta', type=str, default=None,
-                        help='Meta information for the generic discriminator')
-    parser.add_argument('--label_class', type=int, default=-1,
-                        help='Class label used for the discriminator')
-    parser.add_argument('--stepsize', type=float, default=0.02)
+    parser.add_argument(
+        "--model_path",
+        "-M",
+        type=str,
+        default="gpt2-medium",
+        help="pretrained model name or path to local checkpoint",
+    )
+    parser.add_argument(
+        "--bag_of_words",
+        "-B",
+        type=str,
+        default=None,
+        help="Bags of words used for PPLM-BoW. "
+             "Either a BOW id (see list in code) or a filepath. "
+             "Multiple BoWs separated by ;",
+    )
+    parser.add_argument(
+        "--discrim",
+        "-D",
+        type=str,
+        default=None,
+        choices=("clickbait", "sentiment", "toxicity"),
+        help="Discriminator to use for loss-type 2",
+    )
+    parser.add_argument(
+        "--label_class",
+        type=int,
+        default=-1,
+        help="Class label used for the discriminator",
+    )
+    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--length", type=int, default=100)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=10)
     parser.add_argument("--gm_scale", type=float, default=0.9)
     parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument('--nocuda', action='store_true', help='no cuda')
-    parser.add_argument('--uncond', action='store_true',
-                        help='Generate from end-of-text as prefix')
-    parser.add_argument("--cond_text", type=str, default='The lake',
-                        help='Prefix texts to condition on')
-    parser.add_argument('--num_iterations', type=int, default=3)
-    parser.add_argument('--grad_length', type=int, default=10000)
-    parser.add_argument('--num_samples', type=int, default=1,
-                        help='Number of samples to generate from the modified latents')
-    parser.add_argument('--horizon_length', type=int, default=1,
-                        help='Length of future to optimize over')
-    # parser.add_argument('--force-token', action='store_true', help='no cuda')
-    parser.add_argument('--window_length', type=int, default=0,
-                        help='Length of past which is being optimizer; 0 corresponds to infinite window length')
-    parser.add_argument('--decay', action='store_true',
-                        help='whether to decay or not')
-    parser.add_argument('--gamma', type=float, default=1.5)
-    parser.add_argument('--colorama', action='store_true', help='no cuda')
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument("--num_iterations", type=int, default=3)
+    parser.add_argument("--grad_length", type=int, default=10000)
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
+    parser.add_argument(
+        "--horizon_length",
+        type=int,
+        default=1,
+        help="Length of future to optimize over",
+    )
+    parser.add_argument(
+        "--window_length",
+        type=int,
+        default=0,
+        help="Length of past which is being optimized; "
+             "0 corresponds to infinite window length",
+    )
+    parser.add_argument("--decay", action="store_true",
+                        help="whether to decay or not")
+    parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument("--colorama", action="store_true", help="colors keywords")
     args = parser.parse_args()

+    # set Random seed
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)

-    device = 'cpu' if args.nocuda else 'cuda'
+    # set the device
+    device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

+    # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
         args.model_path,
         output_hidden_states=True
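Note: the next hunk shows run_model forwarding every parsed flag to full_text_generation via **vars(args), so each argparse name above must line up with a keyword parameter of that function (or fall into its **kwargs). A minimal sketch of the pattern, with an invented two-flag parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--gamma", type=float, default=1.5)
    args = parser.parse_args([])  # empty argv, so defaults are used
    print(vars(args))             # {'stepsize': 0.02, 'gamma': 1.5}
    # full_text_generation(model=model, **vars(args)) then sees stepsize=0.02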
@@ -753,76 +786,77 @@ def run_model():
     # Freeze GPT-2 weights
     for param in model.parameters():
         param.requires_grad = False
-    pass

+    # figure out conditioning text
     if args.uncond:
-        seq = [[50256, 50256]]
+        tokenized_cond_text = TOKENIZER.encode(
+            [TOKENIZER.bos_token]
+        )
     else:
         raw_text = args.cond_text
         while not raw_text:
-            print('Did you forget to add `--cond-text`? ')
+            print("Did you forget to add `--cond_text`? ")
             raw_text = input("Model prompt >>> ")
-        seq = [[50256] + TOKENIZER.encode(raw_text)]
+        tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)

-    collect_gen = dict()
-    current_index = 0
-    for tokenized_cond_text in seq:
-        text = TOKENIZER.decode(tokenized_cond_text)
-        print("=" * 40 + " Prefix of sentence " + "=" * 40)
-        print(text)
-        print("=" * 80)
-
-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
-            model=model, context=tokenized_cond_text, device=device, **vars(args)
-        )
-
-        text_whole = TOKENIZER.decode(out1.tolist()[0])
-
-        print("=" * 80)
-        print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
-        print(text_whole)
-        print("=" * 80)
-
-        out_perturb_copy = out_perturb
-
-        for out_perturb in out_perturb_copy:
-            # try:
-            #     print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
-            #     text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
-            #     print(text_whole)
-            #     print("=" * 80)
-            # except:
-            #     pass
-            # collect_gen[current_index] = [out, out_perturb, out1]
-            ## Save the prefix, perturbed seq, original seq for each index
-
-            print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
-
-            keyword_tokens = [aa[-1][0] for aa in
-                              actual_words] if actual_words else []
-
-            output_tokens = out_perturb.tolist()[0]
-
-            if args.colorama:
-                import colorama
-
-                text_whole = ''
-                for tokenized_cond_text in output_tokens:
-                    if tokenized_cond_text in keyword_tokens:
-                        text_whole += '%s%s%s' % (
-                            colorama.Fore.GREEN,
-                            TOKENIZER.decode([tokenized_cond_text]),
-                            colorama.Style.RESET_ALL)
-                    else:
-                        text_whole += TOKENIZER.decode([tokenized_cond_text])
-            else:
-                text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
-
-            print(text_whole)
-            print("=" * 80)
-
-            collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
-
-            current_index = current_index + 1
+    print("= Prefix of sentence =")
+    print(TOKENIZER.decode(tokenized_cond_text))
+    print()
+
+    # generate unperturbed and perturbed texts
+    # full_text_generation returns:
+    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
+    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
+        model=model, context=tokenized_cond_text, device=device, **vars(args)
+    )
+
+    # untokenize unperturbed text
+    unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
+
+    print("=" * 80)
+    print("= Unperturbed generated text =")
+    print(unpert_gen_text)
+    print()
+
+    generated_texts = []
+
+    bow_words = set()
+    bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    for bow_list in bow_indices:
+        filtered = list(filter(lambda x: len(x) <= 1, bow_list))
+        bow_words.update(w[0] for w in filtered)
+
+    # iterate through the perturbed texts
+    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
+        try:
+            # untokenize unperturbed text
+            if args.colorama:
+                import colorama
+
+                pert_gen_text = ''
+                for word_id in pert_gen_tok_text.tolist()[0]:
+                    if word_id in bow_words:
+                        pert_gen_text += '{}{}{}'.format(
+                            colorama.Fore.RED,
+                            TOKENIZER.decode([word_id]),
+                            colorama.Style.RESET_ALL
+                        )
+                    else:
+                        pert_gen_text += TOKENIZER.decode([word_id])
+            else:
+                pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])

+            print("= Perturbed generated text {} =".format(i + 1))
+            print(pert_gen_text)
+            print()
+        except:
+            pass
+
+        # keep the prefix, perturbed seq, original seq for each index
+        generated_texts.append(
+            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
+        )

     return
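Note: the new colorama path highlights BoW keywords by wrapping single-token matches in ANSI color codes. A distilled, self-contained sketch of the same technique; highlight() and its inputs are invented for illustration (decode would be TOKENIZER.decode here):

    import colorama

    def highlight(token_ids, keyword_ids, decode):
        out = ""
        for tid in token_ids:
            piece = decode([tid])
            if tid in keyword_ids:
                # wrap BoW keywords in a color code, reset style afterwards
                out += colorama.Fore.RED + piece + colorama.Style.RESET_ALL
            else:
                out += piece
        return out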