Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Cleaned full_text_generation. Identical output as before.

parent 6c9c131780
commit 08c6e456a3
@@ -401,74 +401,6 @@ def full_text_generation(
         device
     )
 
-    # if args.discrim == 'clickbait':
-    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
-    #     classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
-    #     classifier.eval()
-    #     args.label_class = 1 # clickbaity
-    #
-    # elif args.discrim == 'sentiment':
-    #     classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
-    #     #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
-    #     classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
-    #     classifier.eval()
-    #     if args.label_class < 0:
-    #         raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
-    #     #args.label_class = 2 # very pos
-    #     #args.label_class = 3 # very neg
-    #
-    # elif args.discrim == 'toxicity':
-    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
-    #     classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
-    #     classifier.eval()
-    #     args.label_class = 0 # not toxic
-    #
-    # elif args.discrim == 'generic':
-    #     if args.discrim_weights is None:
-    #         raise ValueError('When using a generic discriminator, '
-    #                          'discrim_weights need to be specified')
-    #     if args.discrim_meta is None:
-    #         raise ValueError('When using a generic discriminator, '
-    #                          'discrim_meta need to be specified')
-    #
-    #     with open(args.discrim_meta, 'r') as discrim_meta_file:
-    #         meta = json.load(discrim_meta_file)
-    #
-    #     classifier = ClassificationHead(
-    #         class_size=meta['class_size'],
-    #         embed_size=meta['embed_size'],
-    #         # todo add tokenizer from meta
-    #     ).to(device)
-    #     classifier.load_state_dict(torch.load(args.discrim_weights))
-    #     classifier.eval()
-    #     if args.label_class == -1:
-    #         args.label_class = meta['default_class']
-    #
-    # else:
-    #     classifier = None
-
-    # Get tokens for the list of positive words
-    def list_tokens(word_list):
-        token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in
-                      word_list]
-        # token_list = []
-        # for word in word_list:
-        #     token_list.append(TOKENIZER.encode(" " + word))
-        return token_list
-
-    # good_index = []
-    # if args.bag_of_words:
-    #     bags_of_words = args.bag_of_words.split(";")
-    #     for wordlist in bags_of_words:
-    #         with open(wordlist, "r") as f:
-    #             words = f.read().strip()
-    #             words = words.split('\n')
-    #         good_index.append(list_tokens(words))
-    #
-    #     for good_list in good_index:
-    #         good_list = list(filter(lambda x: len(x) <= 1, good_list))
-    #         actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
-
     bow_indices = []
     if bag_of_words:
         bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
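
Note: the removed list_tokens helper and the get_bag_of_words_indices call that replaces it both turn a bag-of-words list into lists of GPT-2 token ids. A minimal, runnable sketch of that step follows; the "gpt2-medium" checkpoint and the in-memory word list are assumptions standing in for the wordlist files the commented-out code read from disk.

    from transformers import GPT2Tokenizer

    # Assumed checkpoint; the diff only shows a module-level TOKENIZER.
    TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

    def words_to_token_ids(word_list):
        # Encode each word with a leading space so the ids match tokens that
        # appear mid-sentence, as in the removed list_tokens helper.
        return [TOKENIZER.encode(word, add_prefix_space=True) for word in word_list]

    # Stand-in for one bag-of-words file read from disk.
    print(words_to_token_ids(["science", "laboratory", "experiment"]))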
@@ -486,9 +418,9 @@ def full_text_generation(
         print("Using PPLM-Discrim")
 
     else:
-        raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
+        raise Exception("Specify either a bag of words or a discriminator")
 
-    original, _, _ = generate_text_pplm(
+    unpert_gen_tok_text, _, _ = generate_text_pplm(
         model=model,
         context=context,
         device=device,
@@ -497,12 +429,12 @@ def full_text_generation(
     )
     torch.cuda.empty_cache()
 
-    perturbed_list = []
-    discrim_loss_list = []
-    loss_in_time_list = []
+    pert_gen_tok_texts = []
+    discrim_losses = []
+    losses_in_time = []
 
     for i in range(num_samples):
-        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
             model=model,
             context=context,
             device=device,
@@ -525,14 +457,14 @@ def full_text_generation(
             decay=decay,
             gamma=gamma,
         )
-        perturbed_list.append(perturbed)
+        pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:
-            discrim_loss_list.append(discrim_loss.data.cpu().numpy())
-            loss_in_time_list.append(loss_in_time)
+            discrim_losses.append(discrim_loss.data.cpu().numpy())
+            losses_in_time.append(loss_in_time)
 
         torch.cuda.empty_cache()
 
-    return original, perturbed_list, discrim_loss_list, loss_in_time_list
+    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
 
 
 def generate_text_pplm(
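
Note: the renaming above does not change the collection pattern: each sample's generated token tensor and, when a discriminator is used, its loss are appended to plain Python lists, with the loss detached to a CPU numpy value. A toy, runnable illustration of that pattern (with fake tensors standing in for generate_text_pplm output) follows.

    import torch

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(3):  # stands in for num_samples
        # Fake stand-ins for what generate_text_pplm returns.
        pert_gen_tok_text = torch.randint(0, 50257, (1, 10))  # generated token ids
        discrim_loss = torch.tensor(0.5 / (i + 1))            # scalar discriminator loss
        loss_in_time = [0.9, 0.7, 0.5]                        # per-step losses

        pert_gen_tok_texts.append(pert_gen_tok_text)
        discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    print(len(pert_gen_tok_texts), discrim_losses)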
@@ -821,11 +753,14 @@ def run_model():
 
     generated_texts = []
 
-    bow_words = set()
-    bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
-    for bow_list in bow_indices:
-        filtered = list(filter(lambda x: len(x) <= 1, bow_list))
-        bow_words.update(w[0] for w in filtered)
+    bow_word_ids = set()
+    if args.bag_of_words and args.colorama:
+        bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
+        for single_bow_list in bow_indices:
+            # filter out all words in the list composed of more than 1 token
+            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
+            # w[0] is safe because the previous filter guarantees w has only 1 item
+            bow_word_ids.update(w[0] for w in filtered)
 
     # iterate through the perturbed texts
     for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
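
Note: a minimal sketch of how the new bow_word_ids set gets built, using a hard-coded bow_indices value in place of get_bag_of_words_indices: multi-token words are dropped and the surviving single-token entries are flattened into a set of ids.

    # Hard-coded stand-in for get_bag_of_words_indices(...): two bags of words,
    # each word already encoded to a list of token ids.
    bow_indices = [[[5465], [28092, 263], [7523]], [[9593]]]

    bow_word_ids = set()
    for single_bow_list in bow_indices:
        # keep only words that encode to a single token
        filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
        # w[0] is safe because every surviving entry has exactly one id
        bow_word_ids.update(w[0] for w in filtered)

    print(bow_word_ids)  # {5465, 7523, 9593} (order may vary)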
@@ -836,7 +771,7 @@ def run_model():
 
         pert_gen_text = ''
         for word_id in pert_gen_tok_text.tolist()[0]:
-            if word_id in bow_words:
+            if word_id in bow_word_ids:
                 pert_gen_text += '{}{}{}'.format(
                     colorama.Fore.RED,
                     TOKENIZER.decode([word_id]),
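
Note: for reference, a runnable sketch of the colorama highlighting this hunk feeds into: bag-of-words token ids are rendered in red, everything else in the default color. The Style.RESET_ALL argument and the "gpt2-medium" checkpoint are assumptions, since the hunk is cut off before the closing arguments of the format call.

    import colorama
    from transformers import GPT2Tokenizer

    TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")  # assumed checkpoint

    def highlight_bow_tokens(token_ids, bow_word_ids):
        pert_gen_text = ''
        for word_id in token_ids:
            if word_id in bow_word_ids:
                # color the bag-of-words token red, then reset (assumed reset style)
                pert_gen_text += '{}{}{}'.format(
                    colorama.Fore.RED,
                    TOKENIZER.decode([word_id]),
                    colorama.Style.RESET_ALL,
                )
            else:
                pert_gen_text += TOKENIZER.decode([word_id])
        return pert_gen_text

    # e.g. highlight_bow_tokens(pert_gen_tok_text.tolist()[0], bow_word_ids)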