diff --git a/examples/research_projects/vqgan-clip/README.md b/examples/research_projects/vqgan-clip/README.md
new file mode 100644
index 00000000000..aef95093542
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/README.md
@@ -0,0 +1,70 @@
+# Simple VQGAN CLIP
+
+Author: @ErwannMillon
+
+This is a very simple VQGAN-CLIP implementation that was built as part of the Face Editor project. This simplified version lets you generate or edit images using text with just three lines of code. For a more fully featured implementation with masking, more advanced losses, and a full GUI, check out the Face Editor project.
+
+By default this uses a CelebA checkpoint (for generating/editing faces), but an ImageNet checkpoint can also be loaded by passing `vqgan_config` and `vqgan_checkpoint` when instantiating `VQGAN_CLIP`.
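+
+For example, a minimal sketch of pointing the model at a different checkpoint (the file names below are placeholders; substitute whichever config/checkpoint you actually have on disk):
+
+```
+from VQGAN_CLIP import VQGAN_CLIP
+
+# Placeholder paths; replace with the config/checkpoint files you downloaded.
+vqgan_clip = VQGAN_CLIP(
+    vqgan_config="./model_checkpoints/imagenet.yaml",
+    vqgan_checkpoint="./model_checkpoints/imagenet.ckpt",
+)
+```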
+
+Learning rate and iterations can be set by modifying `vqgan_clip.lr` and `vqgan_clip.iterations`.
+
+You can edit images by passing `image_path` to the `generate` function.
+See the `generate` function's docstring to learn more about how to format prompts.
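+
+As a quick illustration, weighted prompts can be given either as a single "|"-separated string with ":weight" suffixes or as a list of (prompt, weight) pairs; this sketch follows the formats documented in the `generate` docstring:
+
+```
+from VQGAN_CLIP import VQGAN_CLIP
+
+vqgan_clip = VQGAN_CLIP()
+
+# Two equivalent ways of weighting prompts (the default weight is 1):
+vqgan_clip.generate("a smiling woman:1 | a woman with brown hair:3")
+vqgan_clip.generate([("a smiling woman", 1), ("a woman with brown hair", 3)])
+```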
+
+## Usage
+The easiest way to test this out is by using the Colab demo.
+
+To install locally:
+- Clone this repo
+- Install git-lfs (Ubuntu: `sudo apt-get install git-lfs`, macOS: `brew install git-lfs`)
+
+Then, in the root of the repo, run:
+
+```
+conda create -n vqganclip python=3.8
+conda activate vqganclip
+git-lfs install
+git clone https://huggingface.co/datasets/erwann/face_editor_model_ckpt model_checkpoints
+pip install -r requirements.txt
+```
+
+### Generate new images
+```
+from VQGAN_CLIP import VQGAN_CLIP
+vqgan_clip = VQGAN_CLIP()
+vqgan_clip.generate("a picture of a smiling woman")
+```
+
+### Edit an image
+To get a test image, run
+`git clone https://huggingface.co/datasets/erwann/vqgan-clip-pic test_images`
+
+To edit:
+```
+from VQGAN_CLIP import VQGAN_CLIP
+vqgan_clip = VQGAN_CLIP()
+
+vqgan_clip.lr = .07
+vqgan_clip.iterations = 15
+vqgan_clip.generate(
+    pos_prompts=["a picture of a beautiful asian woman", "a picture of a woman from Japan"],
+ neg_prompts=["a picture of an Indian person", "a picture of a white person"],
+ image_path="./test_images/face.jpeg",
+ show_intermediate=True,
+ save_intermediate=True,
+)
+```
+
+### Make an animation from the most recent generation
+`vqgan_clip.make_animation()`
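+
+`make_animation` also accepts a few optional arguments; the values below are its defaults, taken from the function signature:
+
+```
+vqgan_clip.make_animation(
+    input_path=None,      # defaults to the save path of the most recent generation
+    output_path=None,     # defaults to "./animation.gif"
+    total_duration=5,     # total animation length in seconds
+    extend_frames=True,   # hold the first and last frames a bit longer
+)
+```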
+
+## Features
+- Positive and negative prompts
+- Multiple prompts
+- Prompt weights
+- Creating GIF animations of the transformations
+- Wandb logging
diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py
new file mode 100644
index 00000000000..c936148c554
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py
@@ -0,0 +1,268 @@
+import os
+from glob import glob
+
+import torch
+import torchvision
+from PIL import Image
+from torch import nn
+
+import imageio
+import wandb
+from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
+from loaders import load_vqgan
+from transformers import CLIPModel, CLIPTokenizerFast
+from utils import get_device, get_timestamp, show_pil
+
+
+class ProcessorGradientFlow:
+ """
+    This wraps the Hugging Face CLIP processor to allow backprop through the image processing step.
+    The original processor forces conversion to PIL images, which is faster for image processing but breaks gradient flow.
+    We use the original tokenizer for the text, but do our own image preprocessing to keep images as torch tensors.
+ """
+
+ def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
+ self.device = device
+ self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
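+        # CLIP's image normalization statistics (the same values the original processor uses).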
+ self.image_mean = [0.48145466, 0.4578275, 0.40821073]
+ self.image_std = [0.26862954, 0.26130258, 0.27577711]
+ self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
+ self.resize = torchvision.transforms.Resize(224)
+ self.center_crop = torchvision.transforms.CenterCrop(224)
+
+ def preprocess_img(self, images):
+ images = self.resize(images)
+ images = self.center_crop(images)
+ images = self.normalize(images)
+ return images
+
+ def __call__(self, text=None, images=None, **kwargs):
+ encoding = self.tokenizer(text=text, **kwargs)
+ encoding["pixel_values"] = self.preprocess_img(images)
+ encoding = {key: value.to(self.device) for (key, value) in encoding.items()}
+ return encoding
+
+
+class VQGAN_CLIP(nn.Module):
+ def __init__(
+ self,
+ iterations=10,
+ lr=0.01,
+ vqgan=None,
+ vqgan_config=None,
+ vqgan_checkpoint=None,
+ clip=None,
+ clip_preprocessor=None,
+ device=None,
+ log=False,
+ save_vector=True,
+ return_val="image",
+ quantize=True,
+ save_intermediate=False,
+ show_intermediate=False,
+ make_grid=False,
+ ) -> None:
+ """
+ Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.
+ """
+ super().__init__()
+ self.latent = None
+ self.device = device if device else get_device()
+ if vqgan:
+ self.vqgan = vqgan
+ else:
+ self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint)
+ self.vqgan.eval()
+ if clip:
+ self.clip = clip
+ else:
+ self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ self.clip.to(self.device)
+ self.clip_preprocessor = ProcessorGradientFlow(device=self.device)
+
+ self.iterations = iterations
+ self.lr = lr
+ self.log = log
+ self.make_grid = make_grid
+ self.return_val = return_val
+ self.quantize = quantize
+ self.latent_dim = self.vqgan.decoder.z_shape
+
+ def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
+ """
+ Make an animation from the intermediate images saved during generation.
+ By default, uses the images from the most recent generation created by the generate function.
+ If you want to use images from a different generation, pass the path to the folder containing the images as input_path.
+ """
+ images = []
+ if output_path is None:
+ output_path = "./animation.gif"
+ if input_path is None:
+ input_path = self.save_path
+        paths = sorted(glob(input_path + "/*"))
+ if not len(paths):
+ raise ValueError(
+ "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
+ " function?)"
+ )
+ if len(paths) == 1:
+ print("Only one image found in save path, (did you pass save_intermediate=True to the generate function?)")
+ frame_duration = total_duration / len(paths)
+ durations = [frame_duration] * len(paths)
+ if extend_frames:
+ durations[0] = 1.5
+ durations[-1] = 3
+ for file_name in paths:
+ if file_name.endswith(".png"):
+ images.append(imageio.imread(file_name))
+ imageio.mimsave(output_path, images, duration=durations)
+ print(f"gif saved to {output_path}")
+
+ def _get_latent(self, path=None, img=None):
+ if not (path or img):
+ raise ValueError("Input either path or tensor")
+ if img is not None:
+ raise NotImplementedError
+ x = preprocess(Image.open(path), target_image_size=256).to(self.device)
+ x_processed = preprocess_vqgan(x)
+ z, *_ = self.vqgan.encode(x_processed)
+ return z
+
+ def _add_vector(self, transform_vector):
+ """Add a vector transform to the base latent and returns the resulting image."""
+ base_latent = self.latent.detach().requires_grad_()
+ trans_latent = base_latent + transform_vector
+ if self.quantize:
+ z_q, *_ = self.vqgan.quantize(trans_latent)
+ else:
+ z_q = trans_latent
+ return self.vqgan.decode(z_q)
+
+ def _get_clip_similarity(self, prompts, image, weights=None):
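+        # Sum of CLIP image-text similarity logits over all prompts, optionally scaled by per-prompt weights.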
+ clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
+ clip_outputs = self.clip(**clip_inputs)
+ similarity_logits = clip_outputs.logits_per_image
+ if weights is not None:
+ similarity_logits = similarity_logits * weights
+ return similarity_logits.sum()
+
+ def _get_clip_loss(self, pos_prompts, neg_prompts, image):
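+        # Minimizing -log(pos similarity) + log(neg similarity) pushes the image toward the
+        # positive prompts and away from the negative ones.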
+ pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
+ if neg_prompts:
+ neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
+ else:
+ neg_logits = torch.tensor([1], device=self.device)
+ loss = -torch.log(pos_logits) + torch.log(neg_logits)
+ return loss
+
+ def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
+ vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
+ optim = torch.optim.Adam([vector], lr=self.lr)
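+        # Optimize only the added vector; the base latent stays fixed.
+        # Each iteration yields an intermediate result.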
+
+ for i in range(self.iterations):
+ optim.zero_grad()
+ transformed_img = self._add_vector(vector)
+ processed_img = loop_post_process(transformed_img)
+            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
+ print("CLIP loss", clip_loss)
+ if self.log:
+ wandb.log({"CLIP Loss": clip_loss})
+ clip_loss.backward(retain_graph=True)
+ optim.step()
+ if self.return_val == "image":
+ yield custom_to_pil(transformed_img[0])
+ else:
+ yield vector
+
+ def _init_logging(self, positive_prompts, negative_prompts, image_path):
+ wandb.init(reinit=True, project="face-editor")
+ wandb.config.update({"Positive Prompts": positive_prompts})
+ wandb.config.update({"Negative Prompts": negative_prompts})
+ wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
+ if image_path:
+ image = Image.open(image_path)
+ image = image.resize((256, 256))
+ wandb.log("Original Image", wandb.Image(image))
+
+ def process_prompts(self, prompts):
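+        # Accepts a single string ("|"-separated, with optional ":weight" suffixes),
+        # a list of strings, or a list of (prompt, weight) pairs.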
+ if not prompts:
+ return []
+ processed_prompts = []
+ weights = []
+ if isinstance(prompts, str):
+ prompts = [prompt.strip() for prompt in prompts.split("|")]
+ for prompt in prompts:
+ if isinstance(prompt, (tuple, list)):
+ processed_prompt = prompt[0]
+ weight = float(prompt[1])
+ elif ":" in prompt:
+ processed_prompt, weight = prompt.split(":")
+ weight = float(weight)
+ else:
+ processed_prompt = prompt
+ weight = 1.0
+ processed_prompts.append(processed_prompt)
+ weights.append(weight)
+ return {
+ "prompts": processed_prompts,
+ "weights": torch.tensor(weights, device=self.device),
+ }
+
+ def generate(
+ self,
+ pos_prompts,
+ neg_prompts=None,
+ image_path=None,
+ show_intermediate=True,
+ save_intermediate=False,
+ show_final=True,
+ save_final=True,
+ save_path=None,
+ ):
+ """Generate an image from the given prompts.
+ If image_path is provided, the image is used as a starting point for the optimization.
+ If image_path is not provided, a random latent vector is used as a starting point.
+ You must provide at least one positive prompt, and optionally provide negative prompts.
+ Prompts must be formatted in one of the following ways:
+ - A single prompt as a string, e.g "A smiling woman"
+ - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
+ - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair: 3" (default weight is 1)
+ - A list of prompts, e.g ["A smiling woman", "a woman with brown hair"]
+ - A list of prompts and weights, e.g [("A smiling woman", 1), ("a woman with brown hair", 3)]
+ """
+ if image_path:
+ self.latent = self._get_latent(image_path)
+ else:
+ self.latent = torch.randn(self.latent_dim, device=self.device)
+ if self.log:
+ self._init_logging(pos_prompts, neg_prompts, image_path)
+
+ assert pos_prompts, "You must provide at least one positive prompt."
+ pos_prompts = self.process_prompts(pos_prompts)
+ neg_prompts = self.process_prompts(neg_prompts)
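+        # If no save path was given, build one from the prompts (plus a timestamp if it exists).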
+ if save_final and save_path is None:
+ save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ else:
+ save_path = save_path + "_" + get_timestamp()
+ os.makedirs(save_path)
+ self.save_path = save_path
+
+ original_img = self.vqgan.decode(self.latent)[0]
+ if show_intermediate:
+ print("Original Image")
+ show_pil(custom_to_pil(original_img))
+
+ original_img = loop_post_process(original_img)
+        for i, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
+            if show_intermediate:
+                show_pil(transformed_img)
+            if save_intermediate:
+                transformed_img.save(os.path.join(self.save_path, f"iter_{i:03d}.png"))
+            if self.log:
+                wandb.log({"Image": wandb.Image(transformed_img)})
+        if show_final:
+            show_pil(transformed_img)
+        if save_final:
+            transformed_img.save(os.path.join(self.save_path, f"iter_{i:03d}_final.png"))
diff --git a/examples/research_projects/vqgan-clip/img_processing.py b/examples/research_projects/vqgan-clip/img_processing.py
new file mode 100644
index 00000000000..221ebd86dae
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/img_processing.py
@@ -0,0 +1,50 @@
+import numpy as np
+import PIL
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from PIL import Image
+
+
+def preprocess(img, target_image_size=256):
+ s = min(img.size)
+
+ if s < target_image_size:
+ raise ValueError(f"min dim for image {s} < {target_image_size}")
+
+ r = target_image_size / s
+ s = (round(r * img.size[1]), round(r * img.size[0]))
+ img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
+ img = TF.center_crop(img, output_size=2 * [target_image_size])
+ img = torch.unsqueeze(T.ToTensor()(img), 0)
+ return img
+
+
+def preprocess_vqgan(x):
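+    # Map image tensors from [0, 1] to the [-1, 1] range expected by the VQGAN.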
+ x = 2.0 * x - 1.0
+ return x
+
+
+def custom_to_pil(x, process=True, mode="RGB"):
+ x = x.detach().cpu()
+ if process:
+ x = post_process_tensor(x)
+ x = x.numpy()
+ if process:
+ x = (255 * x).astype(np.uint8)
+ x = Image.fromarray(x)
+ if not x.mode == mode:
+ x = x.convert(mode)
+ return x
+
+
+def post_process_tensor(x):
+ x = torch.clamp(x, -1.0, 1.0)
+ x = (x + 1.0) / 2.0
+ x = x.permute(1, 2, 0)
+ return x
+
+
+def loop_post_process(x):
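+    # Squeeze the batch dim, post-process to [0, 1] HWC, then reshape back to (1, C, H, W) for CLIP.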
+ x = post_process_tensor(x.squeeze())
+ return x.permute(2, 0, 1).unsqueeze(0)
diff --git a/examples/research_projects/vqgan-clip/loaders.py b/examples/research_projects/vqgan-clip/loaders.py
new file mode 100644
index 00000000000..3fd86522dca
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/loaders.py
@@ -0,0 +1,75 @@
+import importlib
+
+import torch
+
+import yaml
+from omegaconf import OmegaConf
+from taming.models.vqgan import VQModel
+
+
+def load_config(config_path, display=False):
+ config = OmegaConf.load(config_path)
+ if display:
+ print(yaml.dump(OmegaConf.to_container(config)))
+ return config
+
+
+def load_vqgan(device, conf_path=None, ckpt_path=None):
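+    # Defaults point at the face-editor checkpoint cloned into ./model_checkpoints (see the README).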
+ if conf_path is None:
+ conf_path = "./model_checkpoints/vqgan_only.yaml"
+ config = load_config(conf_path, display=False)
+ model = VQModel(**config.model.params)
+ if ckpt_path is None:
+ ckpt_path = "./model_checkpoints/vqgan_only.pt"
+ sd = torch.load(ckpt_path, map_location=device)
+ if ".ckpt" in ckpt_path:
+ sd = sd["state_dict"]
+ model.load_state_dict(sd, strict=True)
+ model.to(device)
+ del sd
+ return model
+
+
+def reconstruct_with_vqgan(x, model):
+ z, _, [_, _, indices] = model.encode(x)
+ print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
+ xrec = model.decode(z)
+ return xrec
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+ if "target" not in config:
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+ model = instantiate_from_config(config)
+ if sd is not None:
+ model.load_state_dict(sd)
+ if gpu:
+ model.cuda()
+ if eval_mode:
+ model.eval()
+ return {"model": model}
+
+
+def load_model(config, ckpt, gpu, eval_mode):
+ # load the specified checkpoint
+ if ckpt:
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ print(f"loaded model from global step {global_step}.")
+ else:
+ pl_sd = {"state_dict": None}
+ global_step = None
+ model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
+ return model, global_step
diff --git a/examples/research_projects/vqgan-clip/requirements.txt b/examples/research_projects/vqgan-clip/requirements.txt
new file mode 100644
index 00000000000..540bac904f2
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/requirements.txt
@@ -0,0 +1,27 @@
+einops
+gradio
+icecream
+imageio
+lpips
+matplotlib
+more_itertools
+numpy
+omegaconf
+opencv_python_headless
+Pillow
+pudb
+pytorch_lightning
+PyYAML
+requests
+scikit_image
+scipy
+setuptools
+streamlit
+taming-transformers
+torch
+torchvision
+tqdm
+transformers==4.26.0
+tokenizers==0.13.2
+typing_extensions
+wandb
diff --git a/examples/research_projects/vqgan-clip/utils.py b/examples/research_projects/vqgan-clip/utils.py
new file mode 100644
index 00000000000..7db45fcbb52
--- /dev/null
+++ b/examples/research_projects/vqgan-clip/utils.py
@@ -0,0 +1,35 @@
+from datetime import datetime
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def freeze_module(module):
+ for param in module.parameters():
+ param.requires_grad = False
+
+
+def get_device():
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ device = "mps"
+ if device == "mps":
+ print(
+ "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch"
+ " errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues"
+ " with generations."
+ )
+ return device
+
+
+def show_pil(img):
+ fig = plt.imshow(img)
+ fig.axes.get_xaxis().set_visible(False)
+ fig.axes.get_yaxis().set_visible(False)
+ plt.show()
+
+
+def get_timestamp():
+ current_time = datetime.now()
+ timestamp = current_time.strftime("%H:%M:%S")
+ return timestamp