forked from Akegarasu/lora-scripts
-
Notifications
You must be signed in to change notification settings - Fork 8
/
gui.py
94 lines (77 loc) · 2.9 KB
/
gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import json
import os
import subprocess
import sys
import webbrowser
from datetime import datetime
from threading import Lock
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import toml
# ASGI application object; the routes and middleware in this file are registered on it.
app = FastAPI()
# Guards against concurrent training runs: acquired (non-blocking) by the
# /api/run handler and released by run_train once the training process ends.
lock = Lock()
# Work around systems that report a wrong MIME type for .js/.css files by
# forcing the correct type on responses produced by the static file handler.
sf = StaticFiles(directory="frontend/dist")
_o_fr = sf.file_response


def _hooked_file_response(*args, **kwargs):
    """Wrap ``StaticFiles.file_response`` and force the type of .js/.css files.

    All arguments are forwarded unchanged to the original ``file_response``.
    The served path is read from the first positional argument (or the
    ``full_path`` keyword when no positional argument is given).

    Returns:
        The response built by the original ``file_response``; for paths ending
        in ``.js`` or ``.css`` its media type and Content-Type header are
        overridden.
    """
    # str() keeps endswith() working if the path arrives as an os.PathLike.
    full_path = str(args[0] if args else kwargs.get("full_path", ""))
    r = _o_fr(*args, **kwargs)

    if full_path.endswith(".js"):
        media_type, content_type = "application/javascript", "application/javascript"
    elif full_path.endswith(".css"):
        media_type, content_type = "text/css", "text/css; charset=utf-8"
    else:
        return r

    r.media_type = media_type
    # The response has already been constructed at this point, so its
    # Content-Type header was computed from the guessed type; assigning
    # media_type alone does not change what is sent. Update the header too.
    r.headers["content-type"] = content_type
    return r


sf.file_response = _hooked_file_response
# Command-line interface: only the listening port is configurable.
parser = argparse.ArgumentParser(
    description="GUI for training network",
)
parser.add_argument(
    "--port",
    type=int,
    default=28000,
    help="Port to run the server on",
)
def run_train(toml_path: str):
    """Run the training script as a subprocess and wait for it to finish.

    Launches ``accelerate launch`` on ``./sd-scripts/train_network.py`` with
    ``toml_path`` passed as ``--config_file``, prints a message describing the
    outcome, and always releases the module-level ``lock`` afterwards so that
    another training run can be started.

    Args:
        toml_path: Path of the TOML config file handed to the training script.
    """
    print(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
    args = [
        "accelerate", "launch", "--num_cpu_threads_per_process", "8",
        "./sd-scripts/train_network.py",
        "--config_file", toml_path,
    ]
    try:
        # The command is an argument list, so it is executed directly rather
        # than through a shell: with shell=True on POSIX only the first item
        # is the command and the remaining items become arguments of the shell
        # itself, which would silently drop every training argument.
        result = subprocess.run(args, env=os.environ)
        if result.returncode != 0:
            print("Training failed / 训练失败")
        else:
            print("Training finished / 训练完成")
    except Exception as e:
        # Best-effort: report the failure (e.g. the executable could not be
        # started) instead of propagating it out of the background task.
        print(f"An error occurred when training / 创建训练进程时出现致命错误: {e}")
    finally:
        # Release the lock taken by the /api/run handler, whatever happened.
        lock.release()
@app.post("/api/run")
async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
    """Save the posted config as a TOML file and start training in the background.

    The request body is parsed as UTF-8 JSON, converted to TOML and written to
    ``toml/<timestamp>.toml``; ``run_train`` is then queued as a background
    task with that file. Only one training run may be active at a time: the
    module-level ``lock`` is taken here and released by ``run_train``.

    Args:
        request: Incoming request whose body is the JSON training config.
        background_tasks: Task queue on which ``run_train`` is scheduled.

    Returns:
        ``{"status": "success"}`` when training was queued, or a
        ``{"status": "fail", "detail": ...}`` dict when a run is already active.

    Raises:
        Any exception raised while reading, parsing or writing the config is
        propagated after the lock has been released.
    """
    acquired = lock.acquire(blocking=False)

    if not acquired:
        print("Training is already running / 已有正在进行的训练")
        return {"status": "fail", "detail": "Training is already running"}

    try:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        toml_file = f"toml/{timestamp}.toml"

        toml_data = await request.body()
        j = json.loads(toml_data.decode("utf-8"))

        # Make sure the output directory exists before writing into it.
        os.makedirs("toml", exist_ok=True)
        # Write UTF-8 explicitly so non-ASCII values do not depend on the
        # platform's default encoding.
        with open(toml_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(j))

        background_tasks.add_task(run_train, toml_file)
    except BaseException:
        # run_train will never run for this request, so nothing else would
        # release the lock; without this every later request would be rejected
        # as "already running".
        lock.release()
        raise

    return {"status": "success"}
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """Stamp every HTTP response with ``Cache-Control: max-age=0``.

    Args:
        request: The incoming request, passed through untouched.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response with the Cache-Control header set.
    """
    downstream = await call_next(request)
    headers = downstream.headers
    headers["Cache-Control"] = "max-age=0"
    return downstream
@app.get("/")
async def index():
    """Serve the frontend entry page for the root URL."""
    page = FileResponse("./frontend/dist/index.html")
    return page
# Serve the built frontend files at "/" through the patched StaticFiles
# instance; registered after the explicit routes defined earlier in this file.
app.mount("/", sf, name="static")
if __name__ == "__main__":
    # Script entry point: parse the port, open the UI in a browser and serve
    # the app on localhost. Unknown command-line arguments are ignored.
    args, _ = parser.parse_known_args()
    url = f"http://127.0.0.1:{args.port}"
    print(f"Server started at {url}")
    if sys.platform == "win32":
        # disable triton on windows
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
    webbrowser.open(url)
    # Listen on the requested port; it was previously hard-coded to 28000, so
    # --port changed the printed/opened URL but not the server itself.
    uvicorn.run(app, host="127.0.0.1", port=args.port, log_level="error")