More cleanup for run_model. Identical output as before.

This commit is contained in:
commit 6c9c131780
parent 7ffe47c888
@@ -39,7 +39,6 @@ from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel


PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
@@ -129,8 +128,7 @@ def perturb_past(
decay=False,
gamma=1.5,
):

#def perturb_past(past, model, prev, classifier, good_index=None,
# def perturb_past(past, model, prev, classifier, good_index=None,
# stepsize=0.01, vocab_size=50257,
# original_probs=None, accumulated_hidden=None, true_past=None,
# grad_norms=None):
@@ -237,7 +235,7 @@ def perturb_past(
future_hidden, dim=1)

predicted_sentiment = classifier(new_accumulated_hidden / (
current_length + 1 + horizon_length))
current_length + 1 + horizon_length))

label = torch.tensor([label_class], device='cuda',
dtype=torch.long)
@@ -349,6 +347,13 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
bow_indices.append(
[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
words])

#bow_words = set()
#for bow_list in bow_indices:
# bow_list = list(filter(lambda x: len(x) <= 1, bow_list))
# bow_words.update(
# (TOKENIZER.decode(word).strip(), word) for word in bow_list)

return bow_indices


@@ -368,28 +373,28 @@ def build_bows_one_hot_vectors(bow_indices):


def full_text_generation(
model,
context=None,
num_samples=1,
device="cuda",
sample=True,
discrim=None,
label_class=None,
bag_of_words=None,
length=100,
grad_length=10000,
stepsize=0.02,
num_iterations=3,
temperature=1.0,
gm_scale=0.9,
kl_scale=0.01,
top_k=10,
window_length=0,
horizon_length=1,
decay=False,
gamma=1.5,
**kwargs
):
model,
context=None,
num_samples=1,
device="cuda",
sample=True,
discrim=None,
label_class=None,
bag_of_words=None,
length=100,
grad_length=10000,
stepsize=0.02,
num_iterations=3,
temperature=1.0,
gm_scale=0.9,
kl_scale=0.01,
top_k=10,
window_length=0,
horizon_length=1,
decay=False,
gamma=1.5,
**kwargs
):
classifier, class_id = get_classifier(
discrim,
label_class,
@@ -465,15 +470,9 @@ def full_text_generation(
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

bow_indices = []
actual_words = None
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

for good_list in bow_indices:
good_list = list(filter(lambda x: len(x) <= 1, good_list))
actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
good_list]

if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
loss_type = PPLM_BOW_DISCRIM
@@ -533,8 +532,7 @@ def full_text_generation(

torch.cuda.empty_cache()

return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words

return original, perturbed_list, discrim_loss_list, loss_in_time_list


def generate_text_pplm(
@@ -611,25 +609,25 @@ def generate_text_pplm(
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
past,
model,
prev,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
past,
model,
prev,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_per_iter)

# Piero modified model call
@@ -666,7 +664,7 @@ def generate_text_pplm(
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))

log_probs = ((log_probs ** gm_scale) * (
unpert_logits ** (1 - gm_scale))) # + SmallConst
unpert_logits ** (1 - gm_scale))) # + SmallConst

log_probs = top_k_filter(log_probs, k=top_k,
probs=True) # + SmallConst
@@ -696,53 +694,88 @@ def generate_text_pplm(

def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
help='pretrained model name or path to local checkpoint')
parser.add_argument('--bag-of-words', '-B', type=str, default=None,
help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
parser.add_argument('--discrim', '-D', type=str, default=None,
choices=(
'clickbait', 'sentiment', 'toxicity', 'generic'),
help='Discriminator to use for loss-type 2')
parser.add_argument('--discrim_weights', type=str, default=None,
help='Weights for the generic discriminator')
parser.add_argument('--discrim_meta', type=str, default=None,
help='Meta information for the generic discriminator')
parser.add_argument('--label_class', type=int, default=-1,
help='Class label used for the discriminator')
parser.add_argument('--stepsize', type=float, default=0.02)
parser.add_argument(
"--model_path",
"-M",
type=str,
default="gpt2-medium",
help="pretrained model name or path to local checkpoint",
)
parser.add_argument(
"--bag_of_words",
"-B",
type=str,
default=None,
help="Bags of words used for PPLM-BoW. "
"Either a BOW id (see list in code) or a filepath. "
"Multiple BoWs separated by ;",
)
parser.add_argument(
"--discrim",
"-D",
type=str,
default=None,
choices=("clickbait", "sentiment", "toxicity"),
help="Discriminator to use for loss-type 2",
)
parser.add_argument(
"--label_class",
type=int,
default=-1,
help="Class label used for the discriminator",
)
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--length", type=int, default=100)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=10)
parser.add_argument("--gm_scale", type=float, default=0.9)
parser.add_argument("--kl_scale", type=float, default=0.01)
parser.add_argument('--nocuda', action='store_true', help='no cuda')
parser.add_argument('--uncond', action='store_true',
help='Generate from end-of-text as prefix')
parser.add_argument("--cond_text", type=str, default='The lake',
help='Prefix texts to condition on')
parser.add_argument('--num_iterations', type=int, default=3)
parser.add_argument('--grad_length', type=int, default=10000)
parser.add_argument('--num_samples', type=int, default=1,
help='Number of samples to generate from the modified latents')
parser.add_argument('--horizon_length', type=int, default=1,
help='Length of future to optimize over')
# parser.add_argument('--force-token', action='store_true', help='no cuda')
parser.add_argument('--window_length', type=int, default=0,
help='Length of past which is being optimizer; 0 corresponds to infinite window length')
parser.add_argument('--decay', action='store_true',
help='whether to decay or not')
parser.add_argument('--gamma', type=float, default=1.5)
parser.add_argument('--colorama', action='store_true', help='no cuda')
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument(
"--uncond", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument(
"--cond_text", type=str, default="The lake",
help="Prefix texts to condition on"
)
parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--grad_length", type=int, default=10000)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples to generate from the modified latents",
)
parser.add_argument(
"--horizon_length",
type=int,
default=1,
help="Length of future to optimize over",
)
parser.add_argument(
"--window_length",
type=int,
default=0,
help="Length of past which is being optimized; "
"0 corresponds to infinite window length",
)
parser.add_argument("--decay", action="store_true",
help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--colorama", action="store_true", help="colors keywords")

args = parser.parse_args()

# set Random seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)

device = 'cpu' if args.nocuda else 'cuda'
# set the device
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
args.model_path,
output_hidden_states=True
@@ -753,76 +786,77 @@ def run_model():
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
pass

# figure out conditioning text
if args.uncond:
seq = [[50256, 50256]]

tokenized_cond_text = TOKENIZER.encode(
[TOKENIZER.bos_token]
)
else:
raw_text = args.cond_text
while not raw_text:
print('Did you forget to add `--cond-text`? ')
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
seq = [[50256] + TOKENIZER.encode(raw_text)]
tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)

collect_gen = dict()
current_index = 0
for tokenized_cond_text in seq:
print("= Prefix of sentence =")
print(TOKENIZER.decode(tokenized_cond_text))
print()

text = TOKENIZER.decode(tokenized_cond_text)
print("=" * 40 + " Prefix of sentence " + "=" * 40)
print(text)
print("=" * 80)
# generate unperturbed and perturbed texts

out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
model=model, context=tokenized_cond_text, device=device, **vars(args)
)
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
model=model, context=tokenized_cond_text, device=device, **vars(args)
)

text_whole = TOKENIZER.decode(out1.tolist()[0])
# untokenize unperturbed text
unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])

print("=" * 80)
print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
print(text_whole)
print("=" * 80)
print("=" * 80)
print("= Unperturbed generated text =")
print(unpert_gen_text)
print()

out_perturb_copy = out_perturb
generated_texts = []

for out_perturb in out_perturb_copy:
# try:
# print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
# text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
# print(text_whole)
# print("=" * 80)
# except:
# pass
# collect_gen[current_index] = [out, out_perturb, out1]
## Save the prefix, perturbed seq, original seq for each index
print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
keyword_tokens = [aa[-1][0] for aa in
actual_words] if actual_words else []
output_tokens = out_perturb.tolist()[0]
bow_words = set()
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
for bow_list in bow_indices:
filtered = list(filter(lambda x: len(x) <= 1, bow_list))
bow_words.update(w[0] for w in filtered)

# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# untokenize unperturbed text
if args.colorama:
import colorama

text_whole = ''
for tokenized_cond_text in output_tokens:
if tokenized_cond_text in keyword_tokens:
text_whole += '%s%s%s' % (
colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
colorama.Style.RESET_ALL)
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_words:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
TOKENIZER.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
text_whole += TOKENIZER.decode([tokenized_cond_text])
pert_gen_text += TOKENIZER.decode([word_id])
else:
text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])

print(text_whole)
print("=" * 80)

collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]

current_index = current_index + 1
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
print()
except:
pass

# keep the prefix, perturbed seq, original seq for each index
generated_texts.append(
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)

return
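Note (illustration, not part of the commit): the hunk around `log_probs = ((log_probs ** gm_scale) * (unpert_logits ** (1 - gm_scale)))` is PPLM's geometric-mean fusion of the perturbed and unperturbed next-token distributions, followed by top-k filtering via the script's own `top_k_filter` helper. A minimal, self-contained sketch of that step is shown below; the names (`fuse_and_sample`, `pert_logits`, `unpert_logits`) and the simple stand-in for the top-k filter are assumptions, not the script's actual code.

import torch
import torch.nn.functional as F

def fuse_and_sample(pert_logits, unpert_logits, gm_scale=0.9, top_k=10):
    # Softmax both next-token distributions, then fuse them with a geometric
    # mean weighted by gm_scale (1.0 keeps only the perturbed distribution,
    # 0.0 only the unperturbed one).
    pert_probs = F.softmax(pert_logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1.0 - gm_scale))

    # Simple top-k filter: zero everything below the k-th largest probability,
    # renormalize, and sample one token id per batch row.
    kth_best = torch.topk(fused, k=top_k, dim=-1).values[..., -1:]
    fused = torch.where(fused < kth_best, torch.zeros_like(fused), fused)
    fused = fused / fused.sum(dim=-1, keepdim=True)
    return torch.multinomial(fused, num_samples=1)

# usage sketch (tensor shapes are [batch, vocab_size]):
# next_token = fuse_and_sample(pert_logits, unpert_logits, gm_scale=0.9, top_k=10)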