-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_edit_art_style.py
96 lines (71 loc) · 3.28 KB
/
demo_edit_art_style.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import argparse
import torch
import pathlib
from lib.pl_utils import UnNormalize
from model_define import StyleTransfer
import torchvision.transforms as transforms
from PIL import Image
import clip
import torchvision.utils as vutils
from torchvision.transforms.functional import adjust_contrast
# Command-line settings for the text-driven style transfer demo.
# NOTE: arguments are parsed at import time; ``opt`` is read by the
# ``__main__`` section below.
parser = argparse.ArgumentParser(description='PyTorch TxST Example')
parser.add_argument('--content', type=str, default='data/content/14.jpg',
                    help='path to the content image')
parser.add_argument('--style', type=str, default='van gogh',
                    help='text prompt describing the target style (e.g. an artist name)')
opt = parser.parse_args()
def read_content_img(img_path, img_siz=512, device='cuda'):
    """Load an image file as a normalized, batched RGB tensor.

    The image is converted to RGB and resized to ``img_siz`` x ``img_siz``
    (aspect ratio is not preserved). It is then turned into a float tensor
    and normalized with the ImageNet channel mean/std.

    Args:
        img_path: Path to the image file.
        img_siz: Side length, in pixels, of the square output image.
        device: Device the returned tensor is moved to. Defaults to
            ``'cuda'``, matching the previous hard-coded ``.cuda()`` call.

    Returns:
        Tensor of shape ``(1, 3, img_siz, img_siz)`` on ``device``.
    """
    m_transform = transforms.Compose([
        transforms.Resize((img_siz, img_siz)),
        transforms.ToTensor(),
        # ImageNet statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Open the file in a context manager so its handle is closed promptly.
    # ``convert`` loads the pixels and returns a new image, so ``img``
    # stays valid after the file is closed.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    return m_transform(img).unsqueeze(0).to(device)
def read_style_img(img_path, img_siz=512):
    """Load a style image together with a CLIP token for its style label.

    The label is the name of the directory that contains ``img_path``. It is
    printed to stdout, and its underscores are replaced by spaces before
    tokenizing.

    Args:
        img_path: Path to the style image; its parent folder names the style.
        img_siz: Side length forwarded to ``read_content_img``.

    Returns:
        ``(image_tensor, token_tensor)``:
            - the image as produced by ``read_content_img``;
            - the label's first CLIP token row with a leading batch axis,
              moved to CUDA.
    """
    image = read_content_img(img_path, img_siz)
    folder = os.path.dirname(img_path)
    label = os.path.basename(folder)
    print(label)
    prompt = label.replace("_", " ")
    token_row = clip.tokenize(prompt)[0]
    return image, token_row.unsqueeze(0).cuda()
def custom_text(text):
    """Tokenize a text prompt with CLIP and return it as a one-row batch.

    Only the first tokenized row is kept, and it is re-wrapped with a leading
    batch axis. The shape is presumably ``(1, 77)`` with CLIP's default
    context length; TODO confirm. No device move happens here; callers call
    ``.cuda()`` themselves.
    """
    batch = clip.tokenize(text)
    first_row = batch[0]
    return first_row.unsqueeze(0)
# Example style prompts (from an earlier batch version of this demo):
# 'Morisot', 'monet', 'munch', 'el-greco', 'kirchner', 'pollock', 'roerich',
# 'picasso', 'cezanne', 'gauguin', 'peploe', 'van gogh', 'kandinsky'
if __name__ == '__main__':
    # Stylize ``opt.content`` with the text prompt ``opt.style``.
    # The result is written to ``output/<content stem>_<style stem>.png``.
    m_model = StyleTransfer.load_from_checkpoint(
        "models/wikiart_all.ckpt", strict=False).cuda()
    # you can also try wikiart_subset.ckpt for testing
    # NOTE(review): the model is never explicitly put in eval() mode here.
    # Confirm that load_from_checkpoint / StyleTransfer does not need it.

    # Pure inference: disable autograd so no graph (and the activations it
    # would keep alive) is built for a backward pass that never happens.
    with torch.no_grad():
        I_c = read_content_img(opt.content)

        # encoding: content features plus CLIP image features of the content
        F_c = m_model.encoder(I_c)
        F_clip_c = m_model.text_editor.encode_img(I_c)

        # text conditioning: CLIP text features of the style prompt, with an
        # extra axis inserted at dim 1 before they go into the transform.
        text_input = custom_text(opt.style).cuda()
        meta = m_model.text_editor.forward(text_input)
        F_clip_text = meta['raw_feat'].unsqueeze(1)
        styled = m_model.transform(F_clip_c['raw_feat'], F_clip_text, F_c)

        # decoding
        I_cs = m_model.decoder(styled)

        # Undo the input normalization, clamp to [0, 1], then boost contrast.
        m_unnormalize = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        transfer = m_unnormalize(I_cs).clamp(0, 1).cpu()
        transfer = adjust_contrast(transfer, 1.5)

    # The output directory may not exist on a fresh checkout. Create it
    # rather than failing when the image is saved.
    os.makedirs('output', exist_ok=True)
    # NOTE: Path(...).stem is also applied to the free-text prompt. It keeps
    # only the last '/'-separated part, minus anything from its last '.'.
    save_path = os.path.join(
        'output',
        pathlib.Path(opt.content).stem + '_' + pathlib.Path(opt.style).stem + '.png')
    # For a single-image batch, save_image performs the same
    # mul(255).add_(0.5).clamp_(0, 255) -> uint8 -> PIL conversion as before.
    vutils.save_image(transfer, save_path)