Cleaned full_text_generation. Identical output as before.

This commit is contained in:
piero 2019-11-27 17:48:46 -08:00 committed by Julien Chaumond
parent 6c9c131780
commit 08c6e456a3

View File

@ -401,74 +401,6 @@ def full_text_generation(
device
)
# if args.discrim == 'clickbait':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
# classifier.eval()
# args.label_class = 1 # clickbaity
#
# elif args.discrim == 'sentiment':
# classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
# #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
# classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
# classifier.eval()
# if args.label_class < 0:
# raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
# #args.label_class = 2 # very pos
# #args.label_class = 3 # very neg
#
# elif args.discrim == 'toxicity':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
# classifier.eval()
# args.label_class = 0 # not toxic
#
# elif args.discrim == 'generic':
# if args.discrim_weights is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_weights need to be specified')
# if args.discrim_meta is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_meta need to be specified')
#
# with open(args.discrim_meta, 'r') as discrim_meta_file:
# meta = json.load(discrim_meta_file)
#
# classifier = ClassificationHead(
# class_size=meta['class_size'],
# embed_size=meta['embed_size'],
# # todo add tokenizer from meta
# ).to(device)
# classifier.load_state_dict(torch.load(args.discrim_weights))
# classifier.eval()
# if args.label_class == -1:
# args.label_class = meta['default_class']
#
# else:
# classifier = None
# Get tokens for the list of positive words
def list_tokens(word_list):
token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in
word_list]
# token_list = []
# for word in word_list:
# token_list.append(TOKENIZER.encode(" " + word))
return token_list
# good_index = []
# if args.bag_of_words:
# bags_of_words = args.bag_of_words.split(";")
# for wordlist in bags_of_words:
# with open(wordlist, "r") as f:
# words = f.read().strip()
# words = words.split('\n')
# good_index.append(list_tokens(words))
#
# for good_list in good_index:
# good_list = list(filter(lambda x: len(x) <= 1, good_list))
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
bow_indices = []
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
@ -486,9 +418,9 @@ def full_text_generation(
print("Using PPLM-Discrim")
else:
raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
raise Exception("Specify either a bag of words or a discriminator")
original, _, _ = generate_text_pplm(
unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model,
context=context,
device=device,
@ -497,12 +429,12 @@ def full_text_generation(
)
torch.cuda.empty_cache()
perturbed_list = []
discrim_loss_list = []
loss_in_time_list = []
pert_gen_tok_texts = []
discrim_losses = []
losses_in_time = []
for i in range(num_samples):
perturbed, discrim_loss, loss_in_time = generate_text_pplm(
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
model=model,
context=context,
device=device,
@ -525,14 +457,14 @@ def full_text_generation(
decay=decay,
gamma=gamma,
)
perturbed_list.append(perturbed)
pert_gen_tok_texts.append(pert_gen_tok_text)
if classifier is not None:
discrim_loss_list.append(discrim_loss.data.cpu().numpy())
loss_in_time_list.append(loss_in_time)
discrim_losses.append(discrim_loss.data.cpu().numpy())
losses_in_time.append(loss_in_time)
torch.cuda.empty_cache()
return original, perturbed_list, discrim_loss_list, loss_in_time_list
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
def generate_text_pplm(
@ -821,11 +753,14 @@ def run_model():
generated_texts = []
bow_words = set()
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
for bow_list in bow_indices:
filtered = list(filter(lambda x: len(x) <= 1, bow_list))
bow_words.update(w[0] for w in filtered)
bow_word_ids = set()
if args.bag_of_words and args.colorama:
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
for single_bow_list in bow_indices:
# filtering all words in the list composed of more than 1 token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
# w[0] because we are sure w has only 1 item because previous fitler
bow_word_ids.update(w[0] for w in filtered)
# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
@ -836,7 +771,7 @@ def run_model():
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_words:
if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
TOKENIZER.decode([word_id]),