Add VQGAN-CLIP research project (#21329)

* Add VQGAN-CLIP research project

* fixed style issues

* Update examples/research_projects/vqgan-clip/README.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/requirements.txt

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/README.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/loaders.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* replace CLIPProcessor with tokenizer, change asserts to exceptions

* rm unused import

* remove large files (jupyter notebook linked in readme, imgs migrated to hf dataset)

* add tokenizers dependency

* Remove comment

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* rm model checkpoints

---------

Co-authored-by: Erwann Millon <erwann@Erwanns-MacBook-Air.local>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Erwann Millon 2023-02-02 14:45:35 -05:00 committed by GitHub
parent fbee82951f
commit ea55bd86b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 525 additions and 0 deletions

View File

@ -0,0 +1,70 @@
# Simple VQGAN CLIP
Author: @ErwannMillon
This is a very simple VQGAN-CLIP implementation that was built as a part of the <a href="https://github.com/ErwannMillon/face-editor">Face Editor project</a>. This simplified version allows you to generate or edit images using text with just three lines of code. For a more fully featured implementation with masking, more advanced losses, and a full GUI, check out the Face Editor project.
By default, this uses a CelebA checkpoint (for generating/editing faces). An ImageNet checkpoint is also available and can be loaded by specifying `vqgan_config` and `vqgan_checkpoint` when instantiating `VQGAN_CLIP`.
Learning rate and iterations can be set by modifying `vqgan_clip.lr` and `vqgan_clip.iterations`.
You can edit images by passing `image_path` to the generate function.
See the generate function's docstring to learn more about how to format prompts.
## Usage
The easiest way to test this out is by <a href="https://colab.research.google.com/drive/1Ez4D1J6-hVkmlXeR5jBPWYyu6CLA9Yor?usp=sharing">using the Colab demo</a>.
To install locally:
- Clone this repo
- Install git-lfs (Ubuntu: `sudo apt-get install git-lfs`, macOS: `brew install git-lfs`)
In the root of the repo run:
```
conda create -n vqganclip python=3.8
conda activate vqganclip
git-lfs install
git clone https://huggingface.co/datasets/erwann/face_editor_model_ckpt model_checkpoints
pip install -r requirements.txt
```
### Generate new images
```
from VQGAN_CLIP import VQGAN_CLIP
vqgan_clip = VQGAN_CLIP()
vqgan_clip.generate("a picture of a smiling woman")
```
### Edit an image
To get a test image, run
`git clone https://huggingface.co/datasets/erwann/vqgan-clip-pic test_images`
To edit:
```
from VQGAN_CLIP import VQGAN_CLIP
vqgan_clip = VQGAN_CLIP()
vqgan_clip.lr = .07
vqgan_clip.iterations = 15
vqgan_clip.generate(
pos_prompts= ["a picture of a beautiful asian woman", "a picture of a woman from Japan"],
neg_prompts=["a picture of an Indian person", "a picture of a white person"],
image_path="./test_images/face.jpeg",
show_intermediate=True,
save_intermediate=True,
)
```
### Make an animation from the most recent generation
`vqgan_clip.make_animation()`
## Features:
- Positive and negative prompts
- Multiple prompts
- Prompt Weights
- Creating GIF animations of the transformations
- Wandb logging

View File

@ -0,0 +1,268 @@
import os
from glob import glob
import torch
import torchvision
from PIL import Image
from torch import nn
import imageio
import wandb
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
from loaders import load_vqgan
from transformers import CLIPModel, CLIPTokenizerFast
from utils import get_device, get_timestamp, show_pil
class ProcessorGradientFlow:
    """
    Gradient-friendly replacement for the huggingface CLIP processor.

    The original processor forces conversion to PIL images, which is faster for image processing but breaks
    gradient flow. Here text is tokenized with the CLIP tokenizer, while images are resized, center-cropped and
    normalized with torchvision transforms so that they stay torch tensors throughout.
    """

    def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
        """
        Args:
            device: Device every returned tensor is moved to.
            clip_model: Name or path passed to `CLIPTokenizerFast.from_pretrained`.
        """
        self.device = device
        # NOTE(review): this default names a different checkpoint than the CLIPModel loaded by VQGAN_CLIP;
        # presumably the tokenizers are interchangeable -- confirm.
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
        # Per-channel statistics used for normalization (presumably CLIP's published values -- confirm).
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        """Resize, center-crop and normalize `images`, returning the transformed tensor."""
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, text=None, images=None, **kwargs):
        """
        Build the inputs for a CLIP forward pass.

        Args:
            text: Prompt(s) forwarded to the tokenizer. Skipped when None.
            images: Image tensor run through `preprocess_img`. Skipped when None.
            **kwargs: Extra keyword arguments forwarded to the tokenizer (e.g. `return_tensors`, `padding`).

        Returns:
            A dict of the tokenizer outputs plus, when `images` is given, a `"pixel_values"` entry, with every
            value moved to `self.device`.
        """
        # Both arguments default to None, so neither may be required unconditionally.
        encoding = self.tokenizer(text=text, **kwargs) if text is not None else {}
        if images is not None:
            encoding["pixel_values"] = self.preprocess_img(images)
        return {key: value.to(self.device) for (key, value) in encoding.items()}
class VQGAN_CLIP(nn.Module):
    """Generate or edit images by optimizing an offset to a VQGAN latent against CLIP text prompts."""

    def __init__(
        self,
        iterations=10,
        lr=0.01,
        vqgan=None,
        vqgan_config=None,
        vqgan_checkpoint=None,
        clip=None,
        clip_preprocessor=None,
        device=None,
        log=False,
        save_vector=True,
        return_val="image",
        quantize=True,
        save_intermediate=False,
        show_intermediate=False,
        make_grid=False,
    ) -> None:
        """
        Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.

        Args:
            iterations: Number of optimization steps run by `generate`.
            lr: Learning rate of the Adam optimizer.
            vqgan: Pre-built VQGAN model. When None, one is loaded with `load_vqgan`.
            vqgan_config: Config path forwarded to `load_vqgan` (only used when `vqgan` is None).
            vqgan_checkpoint: Checkpoint path forwarded to `load_vqgan` (only used when `vqgan` is None).
            clip: Pre-built CLIP model. When None, "openai/clip-vit-base-patch32" is loaded.
            clip_preprocessor: Callable producing CLIP inputs. When None, a `ProcessorGradientFlow` is created.
            device: Target device. When falsy, `get_device()` picks one.
            log: When True, log prompts, losses and images to wandb.
            save_vector, save_intermediate, show_intermediate: Accepted but not used by this class.
            return_val: "image" makes the optimization loop yield PIL images; anything else yields the vector.
            quantize: When True, run the shifted latent through `vqgan.quantize` before decoding.
            make_grid: Stored on the instance; not otherwise used by this class.
        """
        super().__init__()
        self.latent = None
        # Set by `generate`; initialized here so `make_animation` can detect "nothing generated yet".
        self.save_path = None
        self.device = device if device else get_device()
        if vqgan is not None:
            self.vqgan = vqgan
        else:
            self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint)
        self.vqgan.eval()
        if clip is not None:
            self.clip = clip
        else:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.to(self.device)
        # Honour a caller-supplied preprocessor instead of silently discarding it.
        if clip_preprocessor is not None:
            self.clip_preprocessor = clip_preprocessor
        else:
            self.clip_preprocessor = ProcessorGradientFlow(device=self.device)

        self.iterations = iterations
        self.lr = lr
        self.log = log
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.latent_dim = self.vqgan.decoder.z_shape

    def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
        """
        Make an animation from the intermediate images saved during generation.
        By default, uses the images from the most recent generation created by the generate function.
        If you want to use images from a different generation, pass the path to the folder containing the images as input_path.

        Args:
            input_path: Folder holding the `.png` frames. Defaults to the last `generate` save folder.
            output_path: Destination file. Defaults to "./animation.gif".
            total_duration: Value divided evenly across the frames to get each frame's duration.
            extend_frames: When True, the first and last frame durations are overridden with 1.5 and 3.

        Raises:
            ValueError: If no folder is known, the folder is empty, or it contains no `.png` files.
        """
        if output_path is None:
            output_path = "./animation.gif"
        if input_path is None:
            input_path = getattr(self, "save_path", None)
        if input_path is None:
            raise ValueError("No input_path given and no previous generation found, aborting")
        paths = sorted(glob(os.path.join(input_path, "*")))
        if not paths:
            raise ValueError(
                "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
                " function?)"
            )
        # Only .png files become frames, so durations must be sized from the same filtered list.
        frame_paths = [file_name for file_name in paths if file_name.endswith(".png")]
        if not frame_paths:
            raise ValueError("No .png images found in save path, aborting")
        if len(frame_paths) == 1:
            print("Only one image found in save path, (did you pass save_intermediate=True to the generate function?)")
        frame_duration = total_duration / len(frame_paths)
        durations = [frame_duration] * len(frame_paths)
        if extend_frames:
            durations[0] = 1.5
            durations[-1] = 3
        images = [imageio.imread(file_name) for file_name in frame_paths]
        imageio.mimsave(output_path, images, duration=durations)
        print(f"gif saved to {output_path}")

    def _get_latent(self, path=None, img=None):
        """
        Encode the image stored at `path` with the VQGAN encoder and return the latent.

        Raises:
            ValueError: If neither `path` nor `img` is given.
            NotImplementedError: If `img` is given (encoding from a tensor is not supported).
        """
        # `img` may be a tensor, whose truth value is ambiguous, so compare against None explicitly.
        if not path and img is None:
            raise ValueError("Input either path or tensor")
        if img is not None:
            raise NotImplementedError
        x = preprocess(Image.open(path), target_image_size=256).to(self.device)
        x_processed = preprocess_vqgan(x)
        z, *_ = self.vqgan.encode(x_processed)
        return z

    def _add_vector(self, transform_vector):
        """Add a vector transform to the base latent and returns the resulting image."""
        # Detach so gradients flow only into `transform_vector`, not into the encoder graph.
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + transform_vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        """Return the sum of the (optionally weighted) CLIP image-text logits between `image` and `prompts`."""
        clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            similarity_logits = similarity_logits * weights
        return similarity_logits.sum()

    def _get_clip_loss(self, pos_prompts, neg_prompts, image):
        """
        Compute `-log(positive similarity) + log(negative similarity)`.

        `pos_prompts` and `neg_prompts` are the dicts produced by `process_prompts`; a falsy `neg_prompts`
        contributes nothing to the loss.
        """
        # NOTE(review): positive weights are inverted here while negative weights are used as-is -- confirm
        # this asymmetry is intended.
        pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
        else:
            # log(1) == 0, so this leaves only the positive term.
            neg_logits = torch.tensor([1.0], device=self.device)
        return -torch.log(pos_logits) + torch.log(neg_logits)

    def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
        """
        Run `self.iterations` Adam steps on a random latent offset, yielding one result per step.

        Yields a PIL image when `self.return_val == "image"`, otherwise the offset vector itself.
        `original_img` is accepted for interface compatibility and is not used.
        """
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)

        for _ in range(self.iterations):
            optim.zero_grad()
            transformed_img = self._add_vector(vector)
            processed_img = loop_post_process(transformed_img)
            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
            print("CLIP loss", clip_loss)
            if self.log:
                wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            optim.step()
            if self.return_val == "image":
                yield custom_to_pil(transformed_img[0])
            else:
                yield vector

    def _init_logging(self, positive_prompts, negative_prompts, image_path):
        """Start a wandb run and record the prompts, hyper-parameters and (if any) the source image."""
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
        if image_path:
            image = Image.open(image_path)
            image = image.resize((256, 256))
            # wandb.log expects a dict of values as its first argument.
            wandb.log({"Original Image": wandb.Image(image)})

    def process_prompts(self, prompts):
        """
        Normalize prompts into a dict with a `"prompts"` list and a `"weights"` tensor.

        Accepts a pipe-separated string, a list of strings (optionally `"text:weight"`), or a list of
        `(text, weight)` pairs. The default weight is 1.0. Returns an empty list when `prompts` is falsy.
        """
        if not prompts:
            return []
        processed_prompts = []
        weights = []
        if isinstance(prompts, str):
            prompts = [prompt.strip() for prompt in prompts.split("|")]
        for prompt in prompts:
            if isinstance(prompt, (tuple, list)):
                processed_prompt = prompt[0]
                weight = float(prompt[1])
            elif ":" in prompt:
                # Split on the last colon only, so prompts that themselves contain colons still parse.
                processed_prompt, weight = prompt.rsplit(":", 1)
                processed_prompt = processed_prompt.strip()
                weight = float(weight)
            else:
                processed_prompt = prompt
                weight = 1.0
            processed_prompts.append(processed_prompt)
            weights.append(weight)
        return {
            "prompts": processed_prompts,
            "weights": torch.tensor(weights, device=self.device),
        }

    def generate(
        self,
        pos_prompts,
        neg_prompts=None,
        image_path=None,
        show_intermediate=True,
        save_intermediate=False,
        show_final=True,
        save_final=True,
        save_path=None,
    ):
        """Generate an image from the given prompts.
        If image_path is provided, the image is used as a starting point for the optimization.
        If image_path is not provided, a random latent vector is used as a starting point.
        You must provide at least one positive prompt, and optionally provide negative prompts.
        Prompts must be formatted in one of the following ways:
        - A single prompt as a string, e.g "A smiling woman"
        - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
        - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair: 3" (default weight is 1)
        - A list of prompts, e.g ["A smiling woman", "a woman with brown hair"]
        - A list of prompts and weights, e.g [("A smiling woman", 1), ("a woman with brown hair", 3)]

        Raises:
            ValueError: If no positive prompt is provided.
        """
        # Validate first so a bad call fails before any encoding or logging work happens.
        if not pos_prompts:
            raise ValueError("You must provide at least one positive prompt.")
        if image_path:
            self.latent = self._get_latent(image_path)
        else:
            self.latent = torch.randn(self.latent_dim, device=self.device)
        if self.log:
            self._init_logging(pos_prompts, neg_prompts, image_path)

        pos_prompts = self.process_prompts(pos_prompts)
        neg_prompts = self.process_prompts(neg_prompts)
        # A folder is needed whenever anything is written, not only for the final image.
        if save_path is None and (save_final or save_intermediate):
            save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
        if save_path is not None:
            if os.path.exists(save_path):
                # Colons are not valid in file names on every platform, so strip them from the suffix.
                save_path = save_path + "_" + get_timestamp().replace(":", "-")
            os.makedirs(save_path, exist_ok=True)
        self.save_path = save_path

        original_img = self.vqgan.decode(self.latent)[0]
        if show_intermediate:
            print("Original Image")
            show_pil(custom_to_pil(original_img))

        original_img = loop_post_process(original_img)
        last_step = None
        last_img = None
        for step, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
            last_step, last_img = step, transformed_img
            if show_intermediate:
                show_pil(transformed_img)
            if save_intermediate:
                transformed_img.save(os.path.join(self.save_path, f"iter_{step:03d}.png"))
            if self.log:
                wandb.log({"Image": wandb.Image(transformed_img)})
        # With zero iterations the loop never runs and there is no final image to show or save.
        if last_img is None:
            return
        if show_final:
            show_pil(last_img)
        if save_final:
            last_img.save(os.path.join(self.save_path, f"iter_{last_step:03d}_final.png"))

View File

@ -0,0 +1,50 @@
import numpy as np
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
def preprocess(img, target_image_size=256):
    """
    Scale a PIL image so its shorter side equals `target_image_size`, center-crop it to a square and return it
    as a tensor with a leading batch dimension.

    Raises:
        ValueError: If the image's shorter side is smaller than `target_image_size`.
    """
    shortest_side = min(img.size)
    if shortest_side < target_image_size:
        raise ValueError(f"min dim for image {shortest_side} < {target_image_size}")

    scale = target_image_size / shortest_side
    # `img.size` is indexed [1] then [0] to build the size tuple handed to `TF.resize`.
    resized_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    resized = TF.resize(img, resized_size, interpolation=PIL.Image.LANCZOS)
    cropped = TF.center_crop(resized, output_size=2 * [target_image_size])
    return torch.unsqueeze(T.ToTensor()(cropped), 0)
def preprocess_vqgan(x):
    """Apply the affine map `2 * x - 1` to `x` (e.g. sending 0 to -1 and 1 to 1)."""
    return 2.0 * x - 1.0
def custom_to_pil(x, process=True, mode="RGB"):
    """
    Convert a tensor to a PIL image in the requested `mode`.

    With `process=True` the tensor first goes through `post_process_tensor` and is scaled by 255 and cast to
    uint8; with `process=False` its raw numpy data is handed to PIL unchanged.
    """
    tensor = x.detach().cpu()
    if process:
        array = (255 * post_process_tensor(tensor).numpy()).astype(np.uint8)
    else:
        array = tensor.numpy()
    image = Image.fromarray(array)
    return image if image.mode == mode else image.convert(mode)
def post_process_tensor(x):
    """Clamp `x` to [-1, 1], rescale it to [0, 1] and move the first axis of a 3-D tensor to the end."""
    clamped = torch.clamp(x, -1.0, 1.0)
    unit_range = (clamped + 1.0) / 2.0
    return unit_range.permute(1, 2, 0)
def loop_post_process(x):
    """
    Post-process `x` while keeping it a tensor: squeeze it, run `post_process_tensor`, then permute the last
    axis back to the front and add a leading batch dimension.
    """
    channels_last = post_process_tensor(x.squeeze())
    channels_first = channels_last.permute(2, 0, 1)
    return channels_first.unsqueeze(0)

View File

@ -0,0 +1,75 @@
import importlib
import torch
import yaml
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
def load_config(config_path, display=False):
    """Load an OmegaConf config from `config_path`, printing it as YAML first when `display` is true."""
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg
def load_vqgan(device, conf_path=None, ckpt_path=None):
    """
    Build a `VQModel` from a config file and load its weights onto `device`.

    Args:
        device: Device the checkpoint is mapped to and the model is moved to.
        conf_path: Config path. Defaults to "./model_checkpoints/vqgan_only.yaml".
        ckpt_path: Checkpoint path. Defaults to "./model_checkpoints/vqgan_only.pt".

    Returns:
        The model with weights loaded (strictly) and moved to `device`.
    """
    if conf_path is None:
        conf_path = "./model_checkpoints/vqgan_only.yaml"
    if ckpt_path is None:
        ckpt_path = "./model_checkpoints/vqgan_only.pt"

    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)

    # NOTE(security): torch.load unpickles the file; only load checkpoints from sources you trust.
    sd = torch.load(ckpt_path, map_location=device)
    # Decide from the loaded content rather than from a ".ckpt" substring in the path, which also matched
    # directory names and failed for non-string paths.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    model.load_state_dict(sd, strict=True)
    model.to(device)
    del sd
    return model
def reconstruct_with_vqgan(x, model):
    """Encode `x` with `model`, print the latent's trailing shape, and return the decoded latent."""
    encoded = model.encode(x)
    # The encoder result is unpacked as (latent, _, [_, _, indices]); only the latent is used.
    z, _, [_, _, indices] = encoded
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    return model.decode(z)
def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as "package.module.Name" to the object it names.

    When `reload` is true the module is imported and reloaded before the attribute is looked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config):
    """
    Instantiate the object named by `config["target"]`, passing `config["params"]` (if any) as keyword
    arguments.

    Raises:
        KeyError: If `config` has no `target` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """
    Instantiate a model from `config` and optionally load weights, move it to CUDA and switch it to eval mode.

    Args:
        config: Config passed to `instantiate_from_config`.
        sd: State dict to load, or None to keep the freshly initialized weights.
        gpu: When true, call `.cuda()` on the model.
        eval_mode: When true, call `.eval()` on the model.

    Returns:
        A dict with the model stored under the key "model".
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return {"model": net}
def load_model(config, ckpt, gpu, eval_mode):
    """
    Build the model described by `config.model`, loading weights from `ckpt` when one is given.

    Returns:
        A `(model, global_step)` tuple; `global_step` is read from the checkpoint, or None without one.
    """
    if not ckpt:
        # No checkpoint: the model keeps its initial weights.
        pl_sd, global_step = {"state_dict": None}, None
    else:
        # load the specified checkpoint
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        print(f"loaded model from global step {global_step}.")
    loaded = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)
    return loaded["model"], global_step

View File

@ -0,0 +1,27 @@
einops
gradio
icecream
imageio
lpips
matplotlib
more_itertools
numpy
omegaconf
opencv_python_headless
Pillow
pudb
pytorch_lightning
PyYAML
requests
scikit_image
scipy
setuptools
streamlit
taming-transformers
torch
torchvision
tqdm
transformers==4.26.0
tokenizers==0.13.2
typing_extensions
wandb

View File

@ -0,0 +1,35 @@
from datetime import datetime
import matplotlib.pyplot as plt
import torch
def freeze_module(module):
    """Turn off gradient tracking for every parameter of `module`."""
    for weight in module.parameters():
        weight.requires_grad = False
def get_device():
    """
    Pick the device name to run on.

    Returns "mps" when the MPS backend is available and built (printing a warning about it), otherwise
    "cuda" when CUDA is available, otherwise "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Look the backend up defensively so torch builds without `torch.backends.mps` do not raise here.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and mps_backend.is_built():
        device = "mps"
        print(
            "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch"
            " errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues"
            " with generations."
        )
    return device
def show_pil(img):
    """Display `img` with matplotlib, hiding both axes."""
    shown = plt.imshow(img)
    for axis in (shown.axes.get_xaxis(), shown.axes.get_yaxis()):
        axis.set_visible(False)
    plt.show()
def get_timestamp():
    """Return the current local time formatted as "HH:MM:SS"."""
    return datetime.now().strftime("%H:%M:%S")