forked from vastsa/FileCodeBox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
119 lines (102 loc) · 3.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import datetime
import os
import random
import secrets
import uuid

from fastapi import FastAPI, Depends, UploadFile, Form, File
from sqlalchemy.orm import Session
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.staticfiles import StaticFiles

import database
from database import engine, SessionLocal, Base
Base.metadata.create_all(bind=engine)
app = FastAPI()
# Make sure the directory served at /static exists before mounting it;
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('./static', exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Landing-page HTML, read once at import time and served by GET /.
# The handle is closed right away; explicit UTF-8 makes decoding independent of
# the platform default (assumes the template is saved as UTF-8 — confirm).
with open('templates/index.html', 'r', encoding='utf-8') as _index_file:
    index_html = _index_file.read()
# Age in hours after which share rows are deleted (the purge runs in /share)
exp_hour = 24
# Number of wrong pickup codes allowed per client IP
error_count = 5
# Lockout length in minutes
error_minute = 60
# Per-IP failure bookkeeping: host -> {'count': int, 'time': datetime}
error_ip_count = {}
def get_db():
    """Provide one database session and guarantee it is closed.

    Written as a generator so it can be used with ``Depends``: the session is
    yielded to the caller, and the ``finally`` clause closes it once the
    generator is finished or torn down.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
def get_code(db: Session = Depends(get_db)):
    """Return a fresh 5-digit pickup code that no stored share uses yet.

    The pickup code is the only thing needed to retrieve a share, so it is
    drawn from ``secrets`` (a CSPRNG) instead of ``random``, whose output can
    be predicted from previously observed values.

    :param db: database session; ``share`` passes its own session explicitly.
    :return: the code as a decimal string in the range 10000-99999.
    """
    while True:
        number = 10000 + secrets.randbelow(90000)
        # Retry until no existing row already carries this code.
        if not db.query(database.Codes).filter(database.Codes.code == number).first():
            return str(number)
def get_file_name(key, ext, file):
    """Write an uploaded file into a date-based directory under ./static/upload.

    The file is stored at ``./static/upload/<year>/<month>/<day>/<key>.<ext>``.

    :param key: unique stem for the stored file name (the caller passes a uuid4 hex).
    :param ext: extension taken from the client-supplied filename. Path
        separators and NUL bytes are removed so it cannot refer to a different
        directory or produce an unopenable path.
    :param file: upload object whose ``.file`` attribute is a readable binary stream.
    :return: ``(key, size_in_bytes, url_path)``, where ``url_path`` is the stored
        path without its leading ``.`` (it starts with ``/static/upload/``).
    """
    now = datetime.datetime.now()
    path = f'./static/upload/{now.year}/{now.month}/{now.day}/'
    # ext originates from untrusted input; drop characters that act as path separators.
    safe_ext = ext.translate(str.maketrans('', '', '/\\\x00'))
    name = f'{key}.{safe_ext}'
    os.makedirs(path, exist_ok=True)
    data = file.file.read()
    with open(os.path.join(path, name), 'wb') as f:
        f.write(data)
    return key, len(data), path[1:] + name
@app.get('/')
async def index():
    """Serve the landing page.

    Returns the ``templates/index.html`` content that was loaded into
    ``index_html`` at import time; no file I/O happens per request.
    """
    return HTMLResponse(index_html)
@app.post('/')
async def index(request: Request, code: str, db: Session = Depends(get_db)):
    """Look up a share by pickup code, with per-IP brute-force protection.

    Each client host may submit ``error_count`` wrong codes. After that, its
    requests are refused until ``error_minute`` minutes have passed since its
    most recent wrong code. Failures older than that window are forgotten.

    :param request: incoming request; its client host keys ``error_ip_count``.
    :param code: the pickup code to look up.
    :param db: request-scoped session from ``get_db``.
    :return: dict with ``code`` 200 and the matching ``database.Codes`` row on
        success, or ``code`` 404 and a message for a wrong code or a lockout.
    """
    now = datetime.datetime.now()
    window = datetime.timedelta(minutes=error_minute)
    # Starlette types request.client as optional; bucket host-less requests together.
    host = request.client.host if request.client else ''
    record = error_ip_count.get(host)
    if record is not None and now - record['time'] >= window:
        # Window elapsed since the last failure: drop the record (this also lifts
        # an expired lockout, so the lookup below proceeds normally).
        del error_ip_count[host]
        record = None
    if record is not None and record['count'] >= error_count:
        return {'code': 404, 'msg': '请求过于频繁,请稍后再试'}
    info = db.query(database.Codes).filter(database.Codes.code == code).first()
    if info:
        return {'code': 200, 'msg': '取件成功,请点击库查看', 'data': info}
    if record is None:
        record = error_ip_count[host] = {'count': 0, 'time': now}
    record['count'] += 1
    # The lockout is measured from the most recent failure.
    record['time'] = now
    return {'code': 404, 'msg': f'取件码错误,错误{error_count}次将被禁止{error_minute}分钟'}
def _store_code(db, **fields):
    """Insert one ``database.Codes`` row built from *fields* and commit it."""
    db.add(database.Codes(**fields))
    db.commit()


@app.post('/share')
async def share(text: str = Form(default=None), file: UploadFile = File(default=None), db: Session = Depends(get_db)):
    """Store a text snippet or an uploaded file under a newly generated pickup code.

    Before storing anything, every row whose ``use_time`` is older than
    ``exp_hour`` hours is deleted. If both ``text`` and ``file`` are sent, the
    text wins.

    :param text: form field holding text to share (optional).
    :param file: uploaded file to share (optional).
    :param db: request-scoped session from ``get_db``.
    :return: dict with ``code`` 200 plus the pickup code, name and text/path on
        success, or ``code`` 422 when neither ``text`` nor ``file`` was given.
    """
    cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=exp_hour)
    # NOTE(review): only the DB rows are purged; files already written under
    # ./static/upload for expired shares remain on disk — confirm whether that is intended.
    db.query(database.Codes).filter(database.Codes.use_time < cutoff_time).delete()
    db.commit()
    code = get_code(db)
    if text:
        _store_code(
            db,
            code=code,
            text=text,
            type='text/plain',
            key=uuid.uuid4().hex,
            size=len(text),
            used=True,
            name='分享文本',
        )
        return {'code': 200, 'msg': '上传成功,请点击文件库查看',
                'data': {'code': code, 'name': '分享文本', 'text': text}}
    if file:
        # UploadFile.filename may be None; treat that as an empty name instead of crashing.
        filename = file.filename or ''
        key, size, full_path = get_file_name(uuid.uuid4().hex, filename.split('.')[-1], file)
        _store_code(
            db,
            code=code,
            text=full_path,
            type=file.content_type,
            key=key,
            size=size,
            used=True,
            name=file.filename,
        )
        return {'code': 200, 'msg': '上传成功,请点击文件库查看',
                'data': {'code': code, 'name': file.filename, 'text': full_path}}
    return {'code': 422, 'msg': '参数错误', 'data': []}