Improvements: model_path renamed to pretrained_model; the tokenizer is now loaded from pretrained_model; pretrained_model is set to the discriminator's when discrim is specified; sample=False by default, with a CLI flag introduced. To obtain samples identical to before, call the CLI with --sample.
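For reference, a minimal sketch of calling the renamed entry point (the import path, prompt, and length below are assumptions for illustration only):

from run_pplm import run_pplm_example  # assumed module name; not shown in this diff

# Equivalent in spirit to passing --pretrained_model and --sample on the CLI.
# sample=True restores the previous sampling behaviour; the new default,
# sample=False, decodes greedily.
run_pplm_example(
    pretrained_model="gpt2-medium",  # formerly --model_path
    cond_text="The potato",          # hypothetical prompt
    bag_of_words="kitchen",          # key defined in BAG_OF_WORDS_ARCHIVE_MAP
    length=50,
    sample=True,                     # same effect as passing --sample
)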

w4nderlust 2019-11-29 19:59:02 -08:00 committed by Julien Chaumond
parent 75904dae66
commit f10b925015


@@ -43,7 +43,6 @@ PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
BAG_OF_WORDS_ARCHIVE_MAP = {
'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
@@ -65,6 +64,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
"default_class": 1,
"pretrained_model": "gpt2-medium",
},
"sentiment": {
"url": "http://s.yosinski.com/SST_classifier_head.pt",
@@ -72,6 +72,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"very_positive": 2, "very_negative": 3},
"default_class": 3,
"pretrained_model": "gpt2-medium",
},
"toxicity": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
@@ -79,6 +80,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"non_toxic": 0, "toxic": 1},
"default_class": 0,
"pretrained_model": "gpt2-medium",
},
}
@@ -345,8 +347,9 @@ def get_classifier(
return classifier, label_id
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
List[List[int]]]:
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
List[
List[List[int]]]:
bow_indices = []
for id_or_path in bag_of_words_ids_or_paths:
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -356,12 +359,12 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
with open(filepath, "r") as f:
words = f.read().strip().split("\n")
bow_indices.append(
[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
words])
return bow_indices
def build_bows_one_hot_vectors(bow_indices, device='cuda'):
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
if bow_indices is None:
return None
@@ -370,7 +373,7 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
single_bow = torch.tensor(single_bow).to(device)
num_words = single_bow.shape[0]
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
one_hot_bow.scatter_(1, single_bow, 1)
one_hot_bows_vectors.append(one_hot_bow)
return one_hot_bows_vectors
@@ -378,10 +381,11 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
def full_text_generation(
model,
tokenizer,
context=None,
num_samples=1,
device="cuda",
sample=True,
sample=False,
discrim=None,
class_label=None,
bag_of_words=None,
@@ -407,7 +411,8 @@ def full_text_generation(
bow_indices = []
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
tokenizer)
if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
@@ -426,9 +431,11 @@ def full_text_generation(
unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model,
tokenizer=tokenizer,
context=context,
device=device,
length=length,
sample=sample,
perturb=False
)
if device == 'cuda':
@@ -441,6 +448,7 @@ def full_text_generation(
for i in range(num_samples):
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
model=model,
tokenizer=tokenizer,
context=context,
device=device,
sample=sample,
@@ -475,10 +483,11 @@ def full_text_generation(
def generate_text_pplm(
model,
tokenizer,
context=None,
past=None,
device="cuda",
sample=True,
sample=False,
perturb=True,
classifier=None,
class_label=None,
@@ -504,7 +513,8 @@ def generate_text_pplm(
)
# collect one hot vectors for bags of words
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
device)
grad_norms = None
last = None
@@ -612,7 +622,7 @@ def generate_text_pplm(
else torch.cat((output_so_far, last), dim=1)
)
print(TOKENIZER.decode(output_so_far.tolist()[0]))
print(tokenizer.decode(output_so_far.tolist()[0]))
return output_so_far, unpert_discrim_loss, loss_in_time
@@ -631,10 +641,167 @@ def set_generic_model_params(discrim_weights, discrim_meta):
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
def run_model():
def run_pplm_example(
pretrained_model="gpt2-medium",
cond_text="",
uncond=False,
num_samples=1,
bag_of_words=None,
discrim=None,
discrim_weights=None,
discrim_meta=None,
class_label=-1,
length=100,
stepsize=0.02,
temperature=1.0,
top_k=10,
sample=False,
num_iterations=3,
grad_length=10000,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
seed=0,
no_cuda=False,
colorama=False
):
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# set the device
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
if discrim == 'generic':
set_generic_model_params(discrim_weights, discrim_meta)
if discrim is not None:
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
"pretrained_model"
]
print("discrim = {}, setting pretrained_model "
"to discriminator's = {}".format(discrim, pretrained_model))
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
pretrained_model,
output_hidden_states=True
)
model.to(device)
model.eval()
# load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
# figure out conditioning text
if uncond:
tokenized_cond_text = tokenizer.encode(
[tokenizer.bos_token]
)
else:
raw_text = cond_text
while not raw_text:
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
print("= Prefix of sentence =")
print(tokenizer.decode(tokenized_cond_text))
print()
# generate unperturbed and perturbed texts
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
model=model,
tokenizer=tokenizer,
context=tokenized_cond_text,
device=device,
num_samples=num_samples,
bag_of_words=bag_of_words,
discrim=discrim,
class_label=class_label,
length=length,
stepsize=stepsize,
temperature=temperature,
top_k=top_k,
sample=sample,
num_iterations=num_iterations,
grad_length=grad_length,
horizon_length=horizon_length,
window_length=window_length,
decay=decay,
gamma=gamma,
gm_scale=gm_scale,
kl_scale=kl_scale,
)
# untokenize unperturbed text
unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
print("=" * 80)
print("= Unperturbed generated text =")
print(unpert_gen_text)
print()
generated_texts = []
bow_word_ids = set()
if bag_of_words and colorama:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
tokenizer)
for single_bow_list in bow_indices:
# filter out words that encode to more than one token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
# w[0] is safe because the previous filter keeps only single-token words
bow_word_ids.update(w[0] for w in filtered)
# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# untokenize unperturbed text
if colorama:
import colorama
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
tokenizer.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
pert_gen_text += tokenizer.decode([word_id])
else:
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
print()
except:
pass
# keep the prefix, perturbed seq, original seq for each index
generated_texts.append(
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)
return
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
"--pretrained_model",
"-M",
type=str,
default="gpt2-medium",
@@ -675,6 +842,10 @@ def run_model():
parser.add_argument("--gm_scale", type=float, default=0.9)
parser.add_argument("--kl_scale", type=float, default=0.01)
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument(
"--sample", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument(
"--uncond", action="store_true",
help="Generate from end-of-text as prefix"
@@ -711,105 +882,4 @@ def run_model():
help="colors keywords")
args = parser.parse_args()
# set Random seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# set the device
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
if args.discrim == 'generic':
set_generic_model_params(args.discrim_weights, args.discrim_meta)
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
args.model_path,
output_hidden_states=True
)
model.to(device)
model.eval()
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
# figure out conditioning text
if args.uncond:
tokenized_cond_text = TOKENIZER.encode(
[TOKENIZER.bos_token]
)
else:
raw_text = args.cond_text
while not raw_text:
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
print("= Prefix of sentence =")
print(TOKENIZER.decode(tokenized_cond_text))
print()
# generate unperturbed and perturbed texts
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
model=model, context=tokenized_cond_text, device=device, **vars(args)
)
# untokenize unperturbed text
unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
print("=" * 80)
print("= Unperturbed generated text =")
print(unpert_gen_text)
print()
generated_texts = []
bow_word_ids = set()
if args.bag_of_words and args.colorama:
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
for single_bow_list in bow_indices:
# filter out words that encode to more than one token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
# w[0] is safe because the previous filter keeps only single-token words
bow_word_ids.update(w[0] for w in filtered)
# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# untokenize unperturbed text
if args.colorama:
import colorama
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
TOKENIZER.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
pert_gen_text += TOKENIZER.decode([word_id])
else:
pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
print()
except:
pass
# keep the prefix, perturbed seq, original seq for each index
generated_texts.append(
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)
return
if __name__ == '__main__':
run_model()
run_pplm_example(**vars(args))
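
For completeness, a hedged sketch of the discriminator path introduced above (the prompt and class choice are illustrative only):

from run_pplm import run_pplm_example  # assumed module name; not shown in this diff

# When discrim is given, pretrained_model is now overridden by
# DISCRIMINATOR_MODELS_PARAMS["sentiment"]["pretrained_model"] ("gpt2-medium"),
# and the tokenizer is loaded from that same checkpoint.
run_pplm_example(
    cond_text="The weather today",  # hypothetical prompt
    discrim="sentiment",
    class_label=2,                  # "very_positive" in the class_vocab above
    length=50,
)                                   # sample defaults to False, i.e. greedy decoding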