-
Notifications
You must be signed in to change notification settings - Fork 8
/
comfyui_default.py
90 lines (65 loc) · 2.34 KB
/
comfyui_default.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import os
from pathlib import Path
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from torchvision.transforms import InterpolationMode
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
def register_node(identifier: str, display_name: str):
def decorator(cls):
NODE_CLASS_MAPPINGS[identifier] = cls
NODE_DISPLAY_NAME_MAPPINGS[identifier] = display_name
return cls
return decorator
def load_image(path, convert="RGB"):
img = Image.open(path).convert(convert)
img = np.array(img).astype(np.float32) / 255.0
img = torch.from_numpy(img).unsqueeze(0)
return img
def save_image(img: torch.Tensor, path, prompt=None, extra_pnginfo: dict = None):
path = str(path)
if len(img.shape) != 3:
raise ValueError(f"can't take image batch as input, got {img.shape[0]} images")
img = img.permute(2, 0, 1)
if img.shape[0] != 3:
raise ValueError(f"image must have 3 channels, but got {img.shape[0]} channels")
img = img.clamp(0, 1)
img = F.to_pil_image(img)
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for k, v in extra_pnginfo.items():
metadata.add_text(k, json.dumps(v))
img.save(path, pnginfo=metadata, compress_level=4)
subfolder, filename = os.path.split(path)
return {"filename": filename, "subfolder": subfolder, "type": "output"}
@register_node("JWImageLoadRGBIfExists", "Image Load RGB If Exists")
class _:
CATEGORY = "jamesWalker55"
INPUT_TYPES = lambda: {
"required": {
"default": ("IMAGE",),
"path": ("STRING", {"default": "./image.png"}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
def execute(self, path: str, default: torch.Tensor):
assert isinstance(path, str)
assert isinstance(default, torch.Tensor)
if not os.path.exists(path):
return (default,)
img = load_image(path)
return (img,)
@classmethod
def IS_CHANGED(cls, path: str, default: torch.Tensor):
if os.path.exists(path):
mtime = os.path.getmtime(path)
else:
mtime = None
return (mtime, default.__hash__())