diff --git a/examples/research_projects/vqgan-clip/README.md b/examples/research_projects/vqgan-clip/README.md
new file mode 100644
index 00000000000..aef95093542
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/README.md
@@ -0,0 +1,70 @@
+# Simple VQGAN CLIP
+
+Author: @ErwannMillon
+
+This is a very simple VQGAN-CLIP implementation that was built as a part of the Face Editor project. This simplified version allows you to generate or edit images using text with just three lines of code. For a more full-featured implementation with masking, more advanced losses, and a full GUI, check out the Face Editor project.
+
+By default this uses a CelebA checkpoint (for generating/editing faces), but an ImageNet checkpoint can also be loaded by specifying `vqgan_config` and `vqgan_checkpoint` when instantiating `VQGAN_CLIP`.
+
+The learning rate and number of iterations can be set by modifying `vqgan_clip.lr` and `vqgan_clip.iterations`.
+
+You can edit images by passing `image_path` to the `generate` function.
+See the `generate` function's docstring to learn more about how to format prompts.
+
+## Usage
+The easiest way to test this out is by using the Colab demo.
+
+To install locally:
+- Clone this repo
+- Install git-lfs (Ubuntu: `sudo apt-get install git-lfs`, macOS: `brew install git-lfs`)
+
+In the root of the repo run:
+
+```
+conda create -n vqganclip python=3.8
+conda activate vqganclip
+git-lfs install
+git clone https://huggingface.co/datasets/erwann/face_editor_model_ckpt model_checkpoints
+pip install -r requirements.txt
+```
+
+### Generate new images
+```
+from VQGAN_CLIP import VQGAN_CLIP
+vqgan_clip = VQGAN_CLIP()
+vqgan_clip.generate("a picture of a smiling woman")
+```
+
+### Edit an image
+To get a test image, run
+`git clone https://huggingface.co/datasets/erwann/vqgan-clip-pic test_images`
+
+To edit:
+```
+from VQGAN_CLIP import VQGAN_CLIP
+vqgan_clip = VQGAN_CLIP()
+
+vqgan_clip.lr = 0.07
+vqgan_clip.iterations = 15
+vqgan_clip.generate(
+    pos_prompts=["a picture of a beautiful asian woman", "a picture of a woman from Japan"],
+    neg_prompts=["a picture of an Indian person", "a picture of a white person"],
+    image_path="./test_images/face.jpeg",
+    show_intermediate=True,
+    save_intermediate=True,
+)
+```
+
+### Make an animation from the most recent generation
+`vqgan_clip.make_animation()`
+
+## Features:
+- Positive and negative prompts
+- Multiple prompts
+- Prompt weights
+- Creating GIF animations of the transformations
+- Wandb logging
+
+
diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py
new file mode 100644
index 00000000000..c936148c554
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py
@@ -0,0 +1,268 @@
+import os
+from glob import glob
+
+import torch
+import torchvision
+from PIL import Image
+from torch import nn
+
+import imageio
+import wandb
+from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
+from loaders import load_vqgan
+from transformers import CLIPModel, CLIPTokenizerFast
+from utils import get_device, get_timestamp, show_pil
+
+
+class ProcessorGradientFlow:
+    """
+    This wraps the Hugging Face CLIP processor to allow backprop through the image processing step.
+    The original processor forces conversion to PIL images, which is faster for image processing but breaks gradient flow.
+    We call the original processor to get the text embeddings, but use our own image processing to keep images as torch tensors.
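+
+    Example (illustrative sketch; `img_tensor` stands in for any float image batch of shape (N, 3, H, W)
+    with values in [0, 1], and `clip_model` for a `transformers.CLIPModel` instance):
+        processor = ProcessorGradientFlow(device="cpu")
+        inputs = processor(text=["a smiling woman"], images=img_tensor, return_tensors="pt", padding=True)
+        outputs = clip_model(**inputs)  # gradients can flow back to img_tensor (if it requires grad)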
+    """
+
+    def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
+        self.device = device
+        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
+        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
+        self.image_std = [0.26862954, 0.26130258, 0.27577711]
+        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
+        self.resize = torchvision.transforms.Resize(224)
+        self.center_crop = torchvision.transforms.CenterCrop(224)
+
+    def preprocess_img(self, images):
+        images = self.resize(images)
+        images = self.center_crop(images)
+        images = self.normalize(images)
+        return images
+
+    def __call__(self, text=None, images=None, **kwargs):
+        encoding = self.tokenizer(text=text, **kwargs)
+        encoding["pixel_values"] = self.preprocess_img(images)
+        encoding = {key: value.to(self.device) for (key, value) in encoding.items()}
+        return encoding
+
+
+class VQGAN_CLIP(nn.Module):
+    def __init__(
+        self,
+        iterations=10,
+        lr=0.01,
+        vqgan=None,
+        vqgan_config=None,
+        vqgan_checkpoint=None,
+        clip=None,
+        clip_preprocessor=None,
+        device=None,
+        log=False,
+        save_vector=True,
+        return_val="image",
+        quantize=True,
+        save_intermediate=False,
+        show_intermediate=False,
+        make_grid=False,
+    ) -> None:
+        """
+        Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.
+        """
+        super().__init__()
+        self.latent = None
+        self.device = device if device else get_device()
+        if vqgan:
+            self.vqgan = vqgan
+        else:
+            self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint)
+        self.vqgan.eval()
+        if clip:
+            self.clip = clip
+        else:
+            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+        self.clip.to(self.device)
+        self.clip_preprocessor = ProcessorGradientFlow(device=self.device)
+
+        self.iterations = iterations
+        self.lr = lr
+        self.log = log
+        self.make_grid = make_grid
+        self.return_val = return_val
+        self.quantize = quantize
+        self.latent_dim = self.vqgan.decoder.z_shape
+
+    def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
+        """
+        Make an animation from the intermediate images saved during generation.
+        By default, uses the images from the most recent generation created by the generate function.
+        If you want to use images from a different generation, pass the path to the folder containing the images as input_path.
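+
+        Example (illustrative; assumes a generation was run with save_intermediate=True so that
+        intermediate frames exist in self.save_path):
+            vqgan_clip = VQGAN_CLIP()
+            vqgan_clip.generate("a picture of a smiling woman", save_intermediate=True)
+            vqgan_clip.make_animation(output_path="./animation.gif", total_duration=5)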
+        """
+        images = []
+        if output_path is None:
+            output_path = "./animation.gif"
+        if input_path is None:
+            input_path = self.save_path
+        paths = list(sorted(glob(input_path + "/*")))
+        if not len(paths):
+            raise ValueError(
+                "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
+                " function?)"
+            )
+        if len(paths) == 1:
+            print("Only one image found in save path (did you pass save_intermediate=True to the generate function?)")
+        frame_duration = total_duration / len(paths)
+        durations = [frame_duration] * len(paths)
+        if extend_frames:
+            durations[0] = 1.5
+            durations[-1] = 3
+        for file_name in paths:
+            if file_name.endswith(".png"):
+                images.append(imageio.imread(file_name))
+        imageio.mimsave(output_path, images, duration=durations)
+        print(f"GIF saved to {output_path}")
+
+    def _get_latent(self, path=None, img=None):
+        if not (path or img):
+            raise ValueError("Input either path or tensor")
+        if img is not None:
+            raise NotImplementedError
+        x = preprocess(Image.open(path), target_image_size=256).to(self.device)
+        x_processed = preprocess_vqgan(x)
+        z, *_ = self.vqgan.encode(x_processed)
+        return z
+
+    def _add_vector(self, transform_vector):
+        """Add a vector transform to the base latent and return the resulting image."""
+        base_latent = self.latent.detach().requires_grad_()
+        trans_latent = base_latent + transform_vector
+        if self.quantize:
+            z_q, *_ = self.vqgan.quantize(trans_latent)
+        else:
+            z_q = trans_latent
+        return self.vqgan.decode(z_q)
+
+    def _get_clip_similarity(self, prompts, image, weights=None):
+        clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
+        clip_outputs = self.clip(**clip_inputs)
+        similarity_logits = clip_outputs.logits_per_image
+        if weights is not None:
+            similarity_logits = similarity_logits * weights
+        return similarity_logits.sum()
+
+    def _get_clip_loss(self, pos_prompts, neg_prompts, image):
+        # Lower loss when the image is more similar to the positive prompts and less similar to the negative ones.
+        pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
+        if neg_prompts:
+            neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
+        else:
+            neg_logits = torch.tensor([1], device=self.device)
+        loss = -torch.log(pos_logits) + torch.log(neg_logits)
+        return loss
+
+    def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
+        # Optimize a latent offset vector so that the decoded image matches the prompts.
+        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
+        optim = torch.optim.Adam([vector], lr=self.lr)
+
+        for i in range(self.iterations):
+            optim.zero_grad()
+            transformed_img = self._add_vector(vector)
+            processed_img = loop_post_process(transformed_img)
+            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
+            print("CLIP loss", clip_loss)
+            if self.log:
+                wandb.log({"CLIP Loss": clip_loss})
+            clip_loss.backward(retain_graph=True)
+            optim.step()
+            if self.return_val == "image":
+                yield custom_to_pil(transformed_img[0])
+            else:
+                yield vector
+
+    def _init_logging(self, positive_prompts, negative_prompts, image_path):
+        wandb.init(reinit=True, project="face-editor")
+        wandb.config.update({"Positive Prompts": positive_prompts})
+        wandb.config.update({"Negative Prompts": negative_prompts})
+        wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
+        if image_path:
+            image = Image.open(image_path)
+            image = image.resize((256, 256))
+            wandb.log({"Original Image": wandb.Image(image)})
+
+    def process_prompts(self, prompts):
+        if not prompts:
+            return []
+        processed_prompts = []
+        weights = []
+        if isinstance(prompts, str):
+            prompts = [prompt.strip() for prompt in prompts.split("|")]
+        for prompt in prompts:
+            if isinstance(prompt, (tuple, list)):
+                processed_prompt = prompt[0]
+                weight = float(prompt[1])
+            elif ":" in prompt:
+                processed_prompt, weight = prompt.split(":")
+                weight = float(weight)
+            else:
+                processed_prompt = prompt
+                weight = 1.0
+            processed_prompts.append(processed_prompt)
+            weights.append(weight)
+        return {
+            "prompts": processed_prompts,
+            "weights": torch.tensor(weights, device=self.device),
+        }
+
+    def generate(
+        self,
+        pos_prompts,
+        neg_prompts=None,
+        image_path=None,
+        show_intermediate=True,
+        save_intermediate=False,
+        show_final=True,
+        save_final=True,
+        save_path=None,
+    ):
+        """Generate an image from the given prompts.
+        If image_path is provided, the image is used as a starting point for the optimization.
+        If image_path is not provided, a random latent vector is used as a starting point.
+        You must provide at least one positive prompt, and optionally provide negative prompts.
+        Prompts must be formatted in one of the following ways:
+        - A single prompt as a string, e.g. "A smiling woman"
+        - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
+        - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair: 3" (default weight is 1)
+        - A list of prompts, e.g. ["A smiling woman", "a woman with brown hair"]
+        - A list of prompts and weights, e.g. [("A smiling woman", 1), ("a woman with brown hair", 3)]
+        """
+        if image_path:
+            self.latent = self._get_latent(image_path)
+        else:
+            self.latent = torch.randn(self.latent_dim, device=self.device)
+        if self.log:
+            self._init_logging(pos_prompts, neg_prompts, image_path)
+
+        assert pos_prompts, "You must provide at least one positive prompt."
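+        # process_prompts (see the docstring above) normalizes every accepted prompt format into
+        # {"prompts": [str, ...], "weights": torch.tensor([...])}. For example (illustrative values):
+        #   "A smiling woman:1 | a woman with brown hair:3"
+        #   -> {"prompts": ["A smiling woman", "a woman with brown hair"], "weights": tensor([1., 3.])}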
+        pos_prompts = self.process_prompts(pos_prompts)
+        neg_prompts = self.process_prompts(neg_prompts)
+        if save_final and save_path is None:
+            save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
+            if not os.path.exists(save_path):
+                os.makedirs(save_path)
+            else:
+                save_path = save_path + "_" + get_timestamp()
+                os.makedirs(save_path)
+        self.save_path = save_path
+
+        original_img = self.vqgan.decode(self.latent)[0]
+        if show_intermediate:
+            print("Original Image")
+            show_pil(custom_to_pil(original_img))
+
+        original_img = loop_post_process(original_img)
+        for i, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
+            if show_intermediate:
+                show_pil(transformed_img)
+            if save_intermediate:
+                transformed_img.save(os.path.join(self.save_path, f"iter_{i:03d}.png"))
+            if self.log:
+                wandb.log({"Image": wandb.Image(transformed_img)})
+        if show_final:
+            show_pil(transformed_img)
+        if save_final:
+            transformed_img.save(os.path.join(self.save_path, f"iter_{i:03d}_final.png"))
diff --git a/examples/research_projects/vqgan-clip/img_processing.py b/examples/research_projects/vqgan-clip/img_processing.py
new file mode 100644
index 00000000000..221ebd86dae
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/img_processing.py
@@ -0,0 +1,50 @@
+import numpy as np
+import PIL
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from PIL import Image
+
+
+def preprocess(img, target_image_size=256):
+    s = min(img.size)
+
+    if s < target_image_size:
+        raise ValueError(f"min dim for image {s} < {target_image_size}")
+
+    r = target_image_size / s
+    s = (round(r * img.size[1]), round(r * img.size[0]))
+    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
+    img = TF.center_crop(img, output_size=2 * [target_image_size])
+    img = torch.unsqueeze(T.ToTensor()(img), 0)
+    return img
+
+
+def preprocess_vqgan(x):
+    x = 2.0 * x - 1.0
+    return x
+
+
+def custom_to_pil(x, process=True, mode="RGB"):
+    x = x.detach().cpu()
+    if process:
+        x = post_process_tensor(x)
+    x = x.numpy()
+    if process:
+        x = (255 * x).astype(np.uint8)
+    x = Image.fromarray(x)
+    if x.mode != mode:
+        x = x.convert(mode)
+    return x
+
+
+def post_process_tensor(x):
+    x = torch.clamp(x, -1.0, 1.0)
+    x = (x + 1.0) / 2.0
+    x = x.permute(1, 2, 0)
+    return x
+
+
+def loop_post_process(x):
+    x = post_process_tensor(x.squeeze())
+    return x.permute(2, 0, 1).unsqueeze(0)
diff --git a/examples/research_projects/vqgan-clip/loaders.py b/examples/research_projects/vqgan-clip/loaders.py
new file mode 100644
index 00000000000..3fd86522dca
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/loaders.py
@@ -0,0 +1,75 @@
+import importlib
+
+import torch
+
+import yaml
+from omegaconf import OmegaConf
+from taming.models.vqgan import VQModel
+
+
+def load_config(config_path, display=False):
+    config = OmegaConf.load(config_path)
+    if display:
+        print(yaml.dump(OmegaConf.to_container(config)))
+    return config
+
+
+def load_vqgan(device, conf_path=None, ckpt_path=None):
+    if conf_path is None:
+        conf_path = "./model_checkpoints/vqgan_only.yaml"
+    config = load_config(conf_path, display=False)
+    model = VQModel(**config.model.params)
+    if ckpt_path is None:
+        ckpt_path = "./model_checkpoints/vqgan_only.pt"
+    sd = torch.load(ckpt_path, map_location=device)
+    if ".ckpt" in ckpt_path:
+        sd = sd["state_dict"]
+    model.load_state_dict(sd, strict=True)
+    model.to(device)
+    del sd
+    return model
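+
+
+# Example usage (illustrative; assumes the default checkpoint files were cloned into ./model_checkpoints
+# as described in the README):
+#   vqgan = load_vqgan("cpu", conf_path="./model_checkpoints/vqgan_only.yaml", ckpt_path="./model_checkpoints/vqgan_only.pt")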
+
+
+def reconstruct_with_vqgan(x, model):
+    z, _, [_, _, indices] = model.encode(x)
+    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
+    xrec = model.decode(z)
+    return xrec
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+    model = instantiate_from_config(config)
+    if sd is not None:
+        model.load_state_dict(sd)
+    if gpu:
+        model.cuda()
+    if eval_mode:
+        model.eval()
+    return {"model": model}
+
+
+def load_model(config, ckpt, gpu, eval_mode):
+    # load the specified checkpoint
+    if ckpt:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+        global_step = pl_sd["global_step"]
+        print(f"loaded model from global step {global_step}.")
+    else:
+        pl_sd = {"state_dict": None}
+        global_step = None
+    model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
+    return model, global_step
diff --git a/examples/research_projects/vqgan-clip/requirements.txt b/examples/research_projects/vqgan-clip/requirements.txt
new file mode 100644
index 00000000000..540bac904f2
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/requirements.txt
@@ -0,0 +1,27 @@
+einops
+gradio
+icecream
+imageio
+lpips
+matplotlib
+more_itertools
+numpy
+omegaconf
+opencv_python_headless
+Pillow
+pudb
+pytorch_lightning
+PyYAML
+requests
+scikit_image
+scipy
+setuptools
+streamlit
+taming-transformers
+torch
+torchvision
+tqdm
+transformers==4.26.0
+tokenizers==0.13.2
+typing_extensions
+wandb
diff --git a/examples/research_projects/vqgan-clip/utils.py b/examples/research_projects/vqgan-clip/utils.py
new file mode 100644
index 00000000000..7db45fcbb52
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/utils.py
@@ -0,0 +1,35 @@
+from datetime import datetime
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def freeze_module(module):
+    for param in module.parameters():
+        param.requires_grad = False
+
+
+def get_device():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = "mps"
+    if device == "mps":
+        print(
+            "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch"
+            " errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues"
+            " with generations."
+        )
+    return device
+
+
+def show_pil(img):
+    fig = plt.imshow(img)
+    fig.axes.get_xaxis().set_visible(False)
+    fig.axes.get_yaxis().set_visible(False)
+    plt.show()
+
+
+def get_timestamp():
+    current_time = datetime.now()
+    timestamp = current_time.strftime("%H:%M:%S")
+    return timestamp
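+
+
+if __name__ == "__main__":
+    # Quick sanity check (illustrative): report which device generations would run on.
+    print(f"Selected device: {get_device()} (timestamp: {get_timestamp()})")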