Add VQGAN-CLIP research project (#21329)
* Add VQGAN-CLIP research project
* fixed style issues
* Update examples/research_projects/vqgan-clip/README.md
* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py
* Update examples/research_projects/vqgan-clip/requirements.txt
* Update examples/research_projects/vqgan-clip/loaders.py
* replace CLIPProcessor with tokenizer, change asserts to exceptions
* rm unused import
* remove large files (jupyter notebook linked in readme, imgs migrated to hf dataset)
* add tokenizers dependency
* Remove comment
* rm model checkpoints

Co-authored-by: Erwann Millon <erwann@Erwanns-MacBook-Air.local>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent fbee82951f
commit ea55bd86b9

examples/research_projects/vqgan-clip/README.md (new file, 70 lines)

# Simple VQGAN CLIP

Author: @ErwannMillon

This is a very simple VQGAN-CLIP implementation that was built as a part of the <a href="https://github.com/ErwannMillon/face-editor">Face Editor project</a>. This simplified version allows you to generate or edit images using text with just three lines of code. For a more fully featured implementation with masking, more advanced losses, and a full GUI, check out the Face Editor project.

By default this uses a CelebA checkpoint (for generating/editing faces), but an ImageNet checkpoint can also be loaded by specifying `vqgan_config` and `vqgan_checkpoint` when instantiating `VQGAN_CLIP`, as in the sketch below.
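
For example, pointing the model at a custom VQGAN might look like this (a minimal sketch; the config/checkpoint paths are placeholders for your own files, and by default the loader falls back to `./model_checkpoints/vqgan_only.yaml` and `./model_checkpoints/vqgan_only.pt`):

```
from VQGAN_CLIP import VQGAN_CLIP

# Placeholder paths -- substitute your own VQGAN config and checkpoint.
vqgan_clip = VQGAN_CLIP(
    vqgan_config="./model_checkpoints/my_vqgan.yaml",
    vqgan_checkpoint="./model_checkpoints/my_vqgan.ckpt",
)
```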

Learning rate and iterations can be set by modifying `vqgan_clip.lr` and `vqgan_clip.iterations`.

You can edit images by passing `image_path` to the `generate` function.
See the `generate` function's docstring to learn more about how to format prompts; a few of the accepted formats are sketched below.
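
Given a `vqgan_clip = VQGAN_CLIP()` instance, prompts can be combined and weighted (a minimal sketch based on that docstring; weights default to 1):

```
# Pipe-separated prompts
vqgan_clip.generate("a smiling woman | a woman with brown hair")
# Colon-separated weights
vqgan_clip.generate("a smiling woman:1 | a woman with brown hair:3")
# A list of (prompt, weight) tuples
vqgan_clip.generate(pos_prompts=[("a smiling woman", 1), ("a woman with brown hair", 3)])
```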

## Usage
The easiest way to test this out is by <a href="https://colab.research.google.com/drive/1Ez4D1J6-hVkmlXeR5jBPWYyu6CLA9Yor?usp=sharing">using the Colab demo</a>.

To install locally:
- Clone this repo
- Install git-lfs (Ubuntu: `sudo apt-get install git-lfs`, macOS: `brew install git-lfs`)

In the root of the repo run:

```
conda create -n vqganclip python=3.8
conda activate vqganclip
git-lfs install
git clone https://huggingface.co/datasets/erwann/face_editor_model_ckpt model_checkpoints
pip install -r requirements.txt
```

### Generate new images
```
from VQGAN_CLIP import VQGAN_CLIP
vqgan_clip = VQGAN_CLIP()
vqgan_clip.generate("a picture of a smiling woman")
```

### Edit an image
To get a test image, run
`git clone https://huggingface.co/datasets/erwann/vqgan-clip-pic test_images`

To edit:
```
from VQGAN_CLIP import VQGAN_CLIP
vqgan_clip = VQGAN_CLIP()

vqgan_clip.lr = 0.07
vqgan_clip.iterations = 15
vqgan_clip.generate(
    pos_prompts=["a picture of a beautiful asian woman", "a picture of a woman from Japan"],
    neg_prompts=["a picture of an Indian person", "a picture of a white person"],
    image_path="./test_images/face.jpeg",
    show_intermediate=True,
    save_intermediate=True,
)
```

### Make an animation from the most recent generation
`vqgan_clip.make_animation()`
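
`make_animation` also accepts optional arguments; for example (a minimal sketch, with placeholder paths):

```
vqgan_clip.make_animation(
    input_path="./outputs/a_picture_of_a_smiling_woman",  # folder of saved intermediate frames
    output_path="./smiling_woman.gif",
    total_duration=8,
)
```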

## Features:
- Positive and negative prompts
- Multiple prompts
- Prompt weights
- Creating GIF animations of the transformations
- Wandb logging (see the example below)
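
Wandb logging is enabled through the `log` flag (assuming you are already logged in to wandb):

```
vqgan_clip = VQGAN_CLIP(log=True)  # logs prompts, lr, iterations, and intermediate images to wandb
vqgan_clip.generate("a picture of a smiling woman")
```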

examples/research_projects/vqgan-clip/VQGAN_CLIP.py (new file, 268 lines)
import os
from glob import glob

import torch
import torchvision
from PIL import Image
from torch import nn

import imageio
import wandb
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
from loaders import load_vqgan
from transformers import CLIPModel, CLIPTokenizerFast
from utils import get_device, get_timestamp, show_pil


class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to PIL images, which is faster for image processing but breaks gradient flow.
    We call the CLIP tokenizer to get the text inputs, but use our own image processing to keep images as torch tensors.
    """

    def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
        self.device = device
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, text=None, images=None, **kwargs):
        encoding = self.tokenizer(text=text, **kwargs)
        encoding["pixel_values"] = self.preprocess_img(images)
        encoding = {key: value.to(self.device) for (key, value) in encoding.items()}
        return encoding
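
# Example usage (illustrative sketch; `clip_model` and `image_tensor` are assumed to exist):
#   processor = ProcessorGradientFlow(device="cpu")
#   inputs = processor(text=["a face"], images=image_tensor, return_tensors="pt", padding=True)
#   outputs = clip_model(**inputs)  # gradients flow back into image_tensor via pixel_values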

class VQGAN_CLIP(nn.Module):
    def __init__(
        self,
        iterations=10,
        lr=0.01,
        vqgan=None,
        vqgan_config=None,
        vqgan_checkpoint=None,
        clip=None,
        clip_preprocessor=None,
        device=None,
        log=False,
        save_vector=True,
        return_val="image",
        quantize=True,
        save_intermediate=False,
        show_intermediate=False,
        make_grid=False,
    ) -> None:
        """
        Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.
        """
        super().__init__()
        self.latent = None
        self.device = device if device else get_device()
        if vqgan:
            self.vqgan = vqgan
        else:
            self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint)
        self.vqgan.eval()
        if clip:
            self.clip = clip
        else:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.to(self.device)
        self.clip_preprocessor = ProcessorGradientFlow(device=self.device)

        self.iterations = iterations
        self.lr = lr
        self.log = log
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.latent_dim = self.vqgan.decoder.z_shape

    def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
        """
        Make an animation from the intermediate images saved during generation.
        By default, uses the images from the most recent generation created by the generate function.
        If you want to use images from a different generation, pass the path to the folder containing the images as input_path.
        """
        images = []
        if output_path is None:
            output_path = "./animation.gif"
        if input_path is None:
            input_path = self.save_path
        paths = list(sorted(glob(input_path + "/*")))
        if not len(paths):
            raise ValueError(
                "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
                " function?)"
            )
        if len(paths) == 1:
            print("Only one image found in save path (did you pass save_intermediate=True to the generate function?)")
        frame_duration = total_duration / len(paths)
        durations = [frame_duration] * len(paths)
        if extend_frames:
            durations[0] = 1.5
            durations[-1] = 3
        for file_name in paths:
            if file_name.endswith(".png"):
                images.append(imageio.imread(file_name))
        imageio.mimsave(output_path, images, duration=durations)
        print(f"gif saved to {output_path}")

    def _get_latent(self, path=None, img=None):
        if not (path or img):
            raise ValueError("Input either path or tensor")
        if img is not None:
            raise NotImplementedError
        x = preprocess(Image.open(path), target_image_size=256).to(self.device)
        x_processed = preprocess_vqgan(x)
        z, *_ = self.vqgan.encode(x_processed)
        return z

    def _add_vector(self, transform_vector):
        """Add a vector transform to the base latent and return the resulting decoded image."""
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + transform_vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            similarity_logits = similarity_logits * weights
        return similarity_logits.sum()

    def _get_clip_loss(self, pos_prompts, neg_prompts, image):
        pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
        else:
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)

        for i in range(self.iterations):
            optim.zero_grad()
            transformed_img = self._add_vector(vector)
            processed_img = loop_post_process(transformed_img)
            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
            print("CLIP loss", clip_loss)
            if self.log:
                wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            optim.step()
            if self.return_val == "image":
                yield custom_to_pil(transformed_img[0])
            else:
                yield vector

    def _init_logging(self, positive_prompts, negative_prompts, image_path):
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update(dict(lr=self.lr, iterations=self.iterations))
        if image_path:
            image = Image.open(image_path)
            image = image.resize((256, 256))
            wandb.log({"Original Image": wandb.Image(image)})

    def process_prompts(self, prompts):
        if not prompts:
            return []
        processed_prompts = []
        weights = []
        if isinstance(prompts, str):
            prompts = [prompt.strip() for prompt in prompts.split("|")]
        for prompt in prompts:
            if isinstance(prompt, (tuple, list)):
                processed_prompt = prompt[0]
                weight = float(prompt[1])
            elif ":" in prompt:
                processed_prompt, weight = prompt.split(":")
                weight = float(weight)
            else:
                processed_prompt = prompt
                weight = 1.0
            processed_prompts.append(processed_prompt)
            weights.append(weight)
        return {
            "prompts": processed_prompts,
            "weights": torch.tensor(weights, device=self.device),
        }
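
    # Illustration: process_prompts("a smiling woman:1 | a woman with brown hair:3") returns roughly
    # {"prompts": ["a smiling woman", "a woman with brown hair"], "weights": tensor([1., 3.])}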

    def generate(
        self,
        pos_prompts,
        neg_prompts=None,
        image_path=None,
        show_intermediate=True,
        save_intermediate=False,
        show_final=True,
        save_final=True,
        save_path=None,
    ):
        """Generate an image from the given prompts.
        If image_path is provided, the image is used as a starting point for the optimization.
        If image_path is not provided, a random latent vector is used as a starting point.
        You must provide at least one positive prompt, and optionally provide negative prompts.
        Prompts must be formatted in one of the following ways:
        - A single prompt as a string, e.g. "A smiling woman"
        - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
        - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair:3" (default weight is 1)
        - A list of prompts, e.g. ["A smiling woman", "a woman with brown hair"]
        - A list of prompts and weights, e.g. [("A smiling woman", 1), ("a woman with brown hair", 3)]
        """
        if image_path:
            self.latent = self._get_latent(image_path)
        else:
            self.latent = torch.randn(self.latent_dim, device=self.device)
        if self.log:
            self._init_logging(pos_prompts, neg_prompts, image_path)

        if not pos_prompts:
            raise ValueError("You must provide at least one positive prompt.")
        pos_prompts = self.process_prompts(pos_prompts)
        neg_prompts = self.process_prompts(neg_prompts)
        if save_final and save_path is None:
            save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path = save_path + "_" + get_timestamp()
                os.makedirs(save_path)
            self.save_path = save_path

        original_img = self.vqgan.decode(self.latent)[0]
        if show_intermediate:
            print("Original Image")
            show_pil(custom_to_pil(original_img))

        original_img = loop_post_process(original_img)
        for iter, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
            if show_intermediate:
                show_pil(transformed_img)
            if save_intermediate:
                transformed_img.save(os.path.join(self.save_path, f"iter_{iter:03d}.png"))
            if self.log:
                wandb.log({"Image": wandb.Image(transformed_img)})
        if show_final:
            show_pil(transformed_img)
        if save_final:
            transformed_img.save(os.path.join(self.save_path, f"iter_{iter:03d}_final.png"))

examples/research_projects/vqgan-clip/img_processing.py (new file, 50 lines)
import numpy as np
import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image


def preprocess(img, target_image_size=256):
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f"min dim for image {s} < {target_image_size}")

    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return img


def preprocess_vqgan(x):
    # Scale [0, 1] images to [-1, 1], the range the VQGAN expects.
    x = 2.0 * x - 1.0
    return x


def custom_to_pil(x, process=True, mode="RGB"):
    x = x.detach().cpu()
    if process:
        x = post_process_tensor(x)
    x = x.numpy()
    if process:
        x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == mode:
        x = x.convert(mode)
    return x


def post_process_tensor(x):
    # Clamp to [-1, 1], rescale to [0, 1], and move channels last (H, W, C).
    x = torch.clamp(x, -1.0, 1.0)
    x = (x + 1.0) / 2.0
    x = x.permute(1, 2, 0)
    return x


def loop_post_process(x):
    # Post-process and restore the (1, C, H, W) layout so the image can be fed back through CLIP preprocessing.
    x = post_process_tensor(x.squeeze())
    return x.permute(2, 0, 1).unsqueeze(0)

examples/research_projects/vqgan-clip/loaders.py (new file, 75 lines)
import importlib

import torch

import yaml
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel


def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_vqgan(device, conf_path=None, ckpt_path=None):
    if conf_path is None:
        conf_path = "./model_checkpoints/vqgan_only.yaml"
    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)
    if ckpt_path is None:
        ckpt_path = "./model_checkpoints/vqgan_only.pt"
    sd = torch.load(ckpt_path, map_location=device)
    if ".ckpt" in ckpt_path:
        sd = sd["state_dict"]
    model.load_state_dict(sd, strict=True)
    model.to(device)
    del sd
    return model


def reconstruct_with_vqgan(x, model):
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def load_model(config, ckpt, gpu, eval_mode):
    # load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        print(f"loaded model from global step {global_step}.")
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
    return model, global_step

examples/research_projects/vqgan-clip/requirements.txt (new file, 27 lines)
einops
gradio
icecream
imageio
lpips
matplotlib
more_itertools
numpy
omegaconf
opencv_python_headless
Pillow
pudb
pytorch_lightning
PyYAML
requests
scikit_image
scipy
setuptools
streamlit
taming-transformers
torch
torchvision
tqdm
transformers==4.26.0
tokenizers==0.13.2
typing_extensions
wandb

examples/research_projects/vqgan-clip/utils.py (new file, 35 lines)
from datetime import datetime

import matplotlib.pyplot as plt
import torch


def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False


def get_device():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    if device == "mps":
        print(
            "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch"
            " errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues"
            " with generations."
        )
    return device


def show_pil(img):
    fig = plt.imshow(img)
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.show()


def get_timestamp():
    current_time = datetime.now()
    timestamp = current_time.strftime("%H:%M:%S")
    return timestamp