-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
182 lines (153 loc) · 7.79 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
from typing import Union
from fastapi import HTTPException, Depends, Request
import jwt
from constants import API_RATE_LIMIT
from slowapi import Limiter
from slowapi.util import get_remote_address
from utils.db_client import MongoDBHandler
from services.image_generation_service import ImageGenerationService
from services.user_service import SECRET_KEY, UserService
from utils.data_types import APIKey, ChangePasswordDataType, EmailDataType, Prompt, TextPrompt, TextToImage, ImageToImage, UserSigninInfo, ValidatorInfo, ChatCompletion
from utils.db_client import MongoDBHandler
def get_api_key(request: Request):
    """Return the bucket key slowapi uses to rate-limit ``request``.

    A request is bucketed by its ``API_KEY`` header only when that header is a
    registered key; otherwise it is bucketed by client address. Trusting the
    raw header would let any caller dodge the limit (e.g. on the public
    sign-in routes) simply by sending a fresh random ``API_KEY`` value on
    every request.

    Args:
        request: The incoming Starlette/FastAPI request.

    Returns:
        The verified API key, or the client address from
        ``get_remote_address``.
    """
    api_key = request.headers.get("API_KEY")
    # Validate against the same key store api_key_checker uses. ``app`` is the
    # module-level service object; it is looked up at call time, so it is
    # available even though it is assigned further down the module.
    if api_key and api_key in app.dbhandler.get_auth_keys():
        return api_key
    return get_remote_address(request)


# Shared rate limiter for the decorated routes; buckets are chosen by
# get_api_key (verified API key, else client address).
limiter = Limiter(key_func=get_api_key)
def dynamic_rate_limit(request: Request):
    """Return the rate-limit string that should apply to ``request``.

    Every caller currently receives ``API_RATE_LIMIT``. ``request`` is accepted
    (and ignored) so a per-plan lookup can be added without changing the
    signature. None of the routes in this module reference this function yet;
    they pass ``API_RATE_LIMIT`` to ``limiter.limit`` directly.
    """
    # TODO: read the caller's plan (the earlier sketch used a "User-Role"
    # header) and return a higher limit, e.g. PRO_API_RATE_LIMIT, for "pro".
    limit = API_RATE_LIMIT
    return limit
# Module-level wiring, executed at import time: a single MongoDBHandler is
# shared by the user service and the image-generation service.
dbhandler = MongoDBHandler()
# Verify the DB connection at startup by querying server info. This presumably
# round-trips to MongoDB (pymongo-style client), so an unreachable DB would fail
# here at import rather than on the first request — confirm. The full
# server-info document is printed to stdout.
print(dbhandler.client.server_info())
# Initialize the UserService with the shared dbhandler.
user_service = UserService(dbhandler)
# Service wrapper; the routes below are registered on its ``.app`` attribute.
app = ImageGenerationService(dbhandler, user_service)
async def api_key_checker(request: Request = None):
    """FastAPI dependency that rejects requests without a registered API key.

    The key is taken from the first non-empty source, in order:
      1. the ``API_KEY`` request header,
      2. a ``"key"`` field in a JSON-object request body,
      3. the ``Authorization`` header, with a leading ``"Bearer "`` removed.

    Args:
        request: The incoming request, injected by FastAPI.

    Raises:
        HTTPException: 403 if no key is supplied or the key is not in
            ``app.dbhandler.get_auth_keys()``.
    """
    # ``request.client`` is None when the ASGI scope carries no peer address.
    client_host = request.client.host if request.client else None
    print(client_host, flush=True)

    # Best-effort body parse: GET requests and non-JSON bodies raise here, and
    # the key may still arrive via a header, so any failure just means
    # "no key in the body".
    try:
        json_data = await request.json()
    except Exception as e:
        print(e, flush=True)
        json_data = {}
    # Valid JSON may be a list/str/number; only an object can carry "key".
    if not isinstance(json_data, dict):
        json_data = {}

    api_key = request.headers.get("API_KEY") or json_data.get("key")
    if not api_key:
        # A missing Authorization header must end in a 403 below, not in an
        # AttributeError on None (which surfaced as a 500).
        authorization = request.headers.get("Authorization") or ""
        if authorization.startswith("Bearer "):
            authorization = authorization[len("Bearer "):]
        api_key = authorization

    # NOTE(review): get_auth_keys() is called synchronously inside this async
    # dependency; if it performs blocking DB I/O it stalls the event loop.
    if not api_key or api_key not in app.dbhandler.get_auth_keys():
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
async def is_admin(request: Request):
    """FastAPI dependency guarding the admin-only routes.

    Expects an ``Authorization: <scheme> <token>`` header and verifies that
    ``<token>`` is an HS256 JWT signed with ``SECRET_KEY``.

    Args:
        request: The incoming request, injected by FastAPI.

    Returns:
        The decoded JWT payload. The return value is ignored when the
        dependency is used via ``dependencies=[...]``.

    Raises:
        HTTPException: 403 if the header is missing or malformed, or if the
            token fails JWT verification.
    """
    token = request.headers.get('Authorization')
    # Require at least "<scheme> <token>"; the scheme itself is not checked.
    if not token or len(token.split(" ")) < 2:
        raise HTTPException(status_code=403, detail="Not an admin")
    token = token.split(" ")[1]
    try:
        data = jwt.decode(token, SECRET_KEY, algorithms=['HS256'])
    except jwt.PyJWTError as e:
        # Only JWT verification failures map to 403; anything else (e.g. a
        # misconfigured SECRET_KEY) now propagates instead of being hidden.
        print("=== exception ===>", e)
        raise HTTPException(status_code=403, detail="Invalid admin token") from e
    # NOTE(review): only the signature and PyJWT's registered-claim checks
    # (such as ``exp``, when present) are verified here. Nothing confirms the
    # token was issued to an admin. If user tokens are signed with the same
    # SECRET_KEY, any user token passes. Confirm this and add an explicit
    # role/claim check on ``data``.
    return data
@app.app.post("/api/v1/txt2img", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def txt2img_api2(request: Request, data: TextToImage):
    # API-key-protected text-to-image endpoint; the work is delegated to
    # ImageGenerationService.txt2img_api.
    response = await app.txt2img_api(request, data)
    return response
@app.app.post("/get_credentials")
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def get_credentials(request: Request, validator_info: ValidatorInfo):
    # Unlike most routes here, this has no api_key_checker dependency.
    # NOTE(review): confirm ImageGenerationService.get_credentials
    # authenticates the validator itself.
    credentials = await app.get_credentials(request, validator_info)
    return credentials
@app.app.post("/generate", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def generate(request: Request, prompt: Union[Prompt, TextPrompt]):
    # ``request`` is not forwarded, but it must stay in the signature:
    # slowapi's limit decorator locates the Request through it.
    generated = await app.generate(prompt)
    return generated
@app.app.get("/get_validators", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def get_validators(request: Request):
    # API-key-protected listing delegated to ImageGenerationService.
    validators = await app.get_validators(request)
    return validators
@app.app.post("/api/v1/img2img", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def img2img_api(request: Request, data: ImageToImage):
    # API-key-protected image-to-image endpoint, delegated to the service.
    response = await app.img2img_api(request, data)
    return response
@app.app.post("/api/v1/instantid", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def instantid_api(request: Request, data: ImageToImage):
    # API-key-protected InstantID endpoint; takes the same ImageToImage body
    # as img2img and is delegated to the service.
    response = await app.instantid_api(request, data)
    return response
@app.app.post("/api/v1/controlnet", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def controlnet_api(request: Request, data: ImageToImage):
    # API-key-protected ControlNet endpoint, delegated to the service.
    response = await app.controlnet_api(request, data)
    return response
@app.app.post("/api/v1/upscale", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def upscale_api(request: Request, data: ImageToImage):
    # API-key-protected upscaling endpoint, delegated to the service.
    response = await app.upscale_api(request, data)
    return response
@app.app.post("/api/v1/chat/completions", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def chat_completions_api(request: Request, data: ChatCompletion):
    # API-key-protected chat-completions endpoint; the service method is
    # named ``chat_completions`` (no ``_api`` suffix).
    completion = await app.chat_completions(request, data)
    return completion
@app.app.post("/api/v1/signin")
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def signin(request: Request, data: UserSigninInfo):
    # Public sign-in (no API key required); a plain ``def``, so FastAPI runs
    # it in its threadpool.
    signed_in_user = user_service.signin(request, data)
    return {"message": "User signed in successfully", "user": signed_in_user}
@app.app.post("/api/v1/signup")
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def signup(request: Request, data: UserSigninInfo):
    # Public account creation; a falsy service result is reported as a 500.
    created_user = user_service.signup(request, data)
    if not created_user:
        raise HTTPException(status_code=500, detail="Failed to create user")
    return {"message": "User created successfully", "user": created_user}
@app.app.get("/api/v1/get_user_info", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def get_user_info(request: Request):
    # A falsy service result is reported as a 500.
    # NOTE(review): this is ``async def`` but calls the service synchronously;
    # if that call does blocking DB I/O it stalls the event loop (the sibling
    # ``def`` routes run in the threadpool instead).
    user_info = user_service.get_user_info(request)
    if not user_info:
        raise HTTPException(status_code=500, detail="Failed to fetch user data")
    return {"message": "User data fetched successfully", "user": user_info}
@app.app.get("/api/v1/add_api_key", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def add_api_key(request: Request):
    # Generates a new API key for the caller; a falsy result becomes a 500.
    # NOTE(review): a GET with a side effect can be fired by prefetchers or
    # crawlers; POST would be safer but would change the public API.
    new_key = user_service.add_api_key(request)
    if not new_key:
        raise HTTPException(status_code=500, detail="Failed to add API key")
    return {"message": "New api key generated", "user": new_key}
@app.app.post("/api/v1/delete_api_key", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def delete_api_key(request: Request, data: APIKey):
    # Deletes the key named in the body; a falsy result becomes a 500.
    deletion = user_service.delete_api_key(request, data.key)
    if not deletion:
        raise HTTPException(status_code=500, detail="Failed to delete API key")
    return {"message": "API key deleted", "user": deletion}
@app.app.get("/api/v1/get_logs", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def get_logs(request: Request):
    # Any falsy result becomes a 500.
    # NOTE(review): if the service returns an empty list for a caller with no
    # logs yet, that legitimate case is also reported as a failure — confirm
    # the service's return contract.
    log_entries = user_service.get_logs(request)
    if not log_entries:
        raise HTTPException(status_code=500, detail="Failed to get logs")
    return {"message": "Retrieved Logs", "logs": log_entries}
@app.app.post("/api/v1/admin/reset_password", dependencies=[Depends(is_admin)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def reset_password(request: Request):
    # Admin-only (guarded by is_admin); the service reads the request itself.
    outcome = await user_service.reset_password(request)
    return outcome
@app.app.post("/api/v1/change_password", dependencies=[Depends(api_key_checker)])
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
def change_password(request: Request, data: ChangePasswordDataType):
    # API-key-protected password change; the service's result is returned as-is.
    outcome = user_service.change_password(request, data)
    return outcome
@app.app.post("/api/v1/stripe-webhook")
async def stripe_webhook(request: Request):
    # Called by Stripe, so there is no API key and no rate limit.
    # NOTE(review): confirm handle_webhooks verifies the Stripe signature.
    outcome = await user_service.handle_webhooks(request)
    return outcome
@app.app.post("/api/v1/admin/signin")
@limiter.limit(API_RATE_LIMIT)  # throttled per bucket chosen by get_api_key
async def admin_signin(request: Request):
    # Admin credential check. Rate-limited like /api/v1/signin so the admin
    # login cannot be brute-forced at an unbounded rate. ``request`` must stay
    # in the signature for slowapi's limit decorator.
    return await user_service.admin_signin(request)
@app.app.post("/api/v1/admin/get_users", dependencies=[Depends(is_admin)])
def get_users(request: Request):
    # Admin-only listing of users (guarded by is_admin); not rate-limited.
    all_users = user_service.admin_get_users(request)
    return {"users": all_users}
@app.app.post("/api/v1/admin/delete_user", dependencies=[Depends(is_admin)])
async def delete_user(request: Request):
    # Admin-only user deletion; the service coroutine's result is returned
    # unchanged.
    return await user_service.admin_delete_user(request)