transformers/examples/research_projects/vqgan-clip/utils.py
Erwann Millon ea55bd86b9
Add VQGAN-CLIP research project (#21329)
* Add VQGAN-CLIP research project

* fixed style issues

* Update examples/research_projects/vqgan-clip/README.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/requirements.txt

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/README.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/VQGAN_CLIP.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/research_projects/vqgan-clip/loaders.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* replace CLIPProcessor with tokenizer, change asserts to exceptions

* rm unused import

* remove large files (jupyter notebook linked in readme, imgs migrated to hf dataset)

* add tokenizers dependency

* Remove comment

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* rm model checkpoints

---------

Co-authored-by: Erwann Millon <erwann@Erwanns-MacBook-Air.local>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2023-02-02 14:45:35 -05:00


from datetime import datetime

import matplotlib.pyplot as plt
import torch
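

def freeze_module(module):
    """Disable gradient tracking for every parameter of `module` so optimization leaves it unchanged."""
    for param in module.parameters():
        param.requires_grad = False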


def get_device():
    """Return the torch device name to use: "mps" if available, else "cuda" if available, else "cpu"."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Apple Silicon backend takes priority when it is both available and built into this torch install.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    if device == "mps":
        print(
            "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch"
            " errors. I recommend using CUDA on a Colab notebook or CPU instead if you're facing inexplicable issues"
            " with generations."
        )
    return device
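

def show_pil(img):
    """Display an image (e.g. a PIL.Image) with matplotlib, hiding both axes."""
    # plt.imshow returns an AxesImage; its .axes attribute is the Axes it was drawn on.
    im = plt.imshow(img)
    im.axes.get_xaxis().set_visible(False)
    im.axes.get_yaxis().set_visible(False)
    plt.show()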


def get_timestamp():
    """Return the current local time formatted as HH:MM:SS."""
    current_time = datetime.now()
    timestamp = current_time.strftime("%H:%M:%S")
    return timestamp
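The sketch below illustrates what `freeze_module` achieves in a VQGAN-CLIP style loop. In that setup, pretrained networks are usually held fixed while only a latent input is optimized. The example repeats `freeze_module` so it runs on its own. The `nn.Linear` "decoder", the all-ones target, the learning rate, and the five-step loop are hypothetical placeholders, not the project's actual models or hyperparameters.

```python
import torch
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Same logic as utils.freeze_module: stop tracking gradients for all parameters."""
    for param in module.parameters():
        param.requires_grad = False


# Hypothetical stand-in for a pretrained network that should stay fixed.
torch.manual_seed(0)
decoder = nn.Linear(4, 4)
freeze_module(decoder)

# Hypothetical trainable latent: the only tensor the optimizer is allowed to update.
latent = torch.zeros(1, 4, requires_grad=True)
decoder_weight_before = decoder.weight.detach().clone()
optimizer = torch.optim.Adam([latent], lr=0.1)

target = torch.ones(1, 4)
for _ in range(5):
    optimizer.zero_grad()
    loss = ((decoder(latent) - target) ** 2).mean()
    loss.backward()  # gradients flow *through* the frozen decoder back to the latent
    optimizer.step()

# The frozen parameters received no gradient and were not modified; only the latent moved.
print(decoder.weight.grad)  # None
print(torch.equal(decoder.weight, decoder_weight_before))  # True
```

Passing only `[latent]` to the optimizer already keeps the decoder weights fixed. Freezing additionally avoids computing and storing gradients for the pretrained weights. It is also harmless if those parameters are passed to the optimizer, because torch optimizers skip parameters whose `.grad` is `None`.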