forked from Jrohy/webssh
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
113 lines (106 loc) · 2.66 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package main
import (
"embed"
"flag"
"fmt"
"io/fs"
"net/http"
"os"
"strconv"
"strings"
"time"
"webssh/controller"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
)
// f is the embedded file system holding the files matched by web/dist/*.
// staticRouter reads index.html and the static/ subtree from it.
//
//go:embed web/dist/*
var f embed.FS
// Command-line flags whose values are read through pointers.
var (
	// port is the TCP port the HTTP server listens on (-p).
	port = flag.Int("p", 5032, "服务运行端口")
	// v requests printing of the version information (-v).
	v = flag.Bool("v", false, "显示版本号")
	// authInfo carries the optional "user:pass" pair that enables login
	// verification (-a).
	authInfo = flag.String("a", "", "开启账号密码登录验证, '-a user:pass'的格式传参")
)

// Settings that init binds to the -t and -s flags.
var (
	timeout  int
	savePass bool
)

// Build metadata printed by -v. Nothing in this file assigns these values;
// presumably they are injected at link time — TODO confirm against the build
// script.
var version, buildDate, goVersion, gitVersion string

// Credentials extracted from -a by init; both stay empty when -a is not given.
var username, password string
// init registers the -t and -s flags, applies the savePass environment
// override, parses the command line, and then acts on -v and -a.
//
// Precedence for savePass: the flag default (false) is overridden by the
// savePass environment variable, which in turn is overridden by an explicit
// -s on the command line, because flag.Parse runs after the environment
// lookup.
//
// The process exits with status 0 after printing version information for -v,
// and with status 1 when -a is not of the form "user:pass" with both parts
// non-empty.
//
// NOTE(review): parsing flags inside init means every binary that links this
// package parses os.Args before main starts; consider moving this into main.
func init() {
	flag.IntVar(&timeout, "t", 120, "ssh连接超时时间(min)")
	flag.BoolVar(&savePass, "s", false, "保存ssh密码")

	if envVal, ok := os.LookupEnv("savePass"); ok {
		// An unparsable value counts as false rather than aborting startup.
		b, err := strconv.ParseBool(envVal)
		savePass = err == nil && b
	}

	flag.Parse()

	if *v {
		fmt.Printf("Version: %s\n\n", version)
		fmt.Printf("BuildDate: %s\n\n", buildDate)
		fmt.Printf("GoVersion: %s\n\n", goVersion)
		fmt.Printf("GitVersion: %s\n\n", gitVersion)
		os.Exit(0)
	}

	if *authInfo != "" {
		// Split on the first colon only, so the password itself may
		// contain ':' characters.
		accountInfo := strings.SplitN(*authInfo, ":", 2)
		if len(accountInfo) != 2 || accountInfo[0] == "" || accountInfo[1] == "" {
			// A malformed argument is a usage error: report it on stderr
			// and exit non-zero so callers can detect the failure.
			fmt.Fprintln(os.Stderr, "请按'-a user:pass'的格式来传参, 且账号密码都不能为空!")
			os.Exit(1)
		}
		username, password = accountInfo[0], accountInfo[1]
	}
}
// staticRouter registers the routes that serve the embedded front end on
// router:
//
//   - GET /        returns web/dist/index.html from the embedded file system.
//     When a password has been configured, the route is placed behind HTTP
//     basic authentication using the configured username and password.
//   - /static/...  serves the files below web/dist/static.
//
// It panics if the static subtree cannot be derived from the embedded file
// system, since that means the binary was built without its assets.
func staticRouter(router *gin.Engine) {
	// serveIndex is shared by the authenticated and unauthenticated routes.
	serveIndex := func(c *gin.Context) {
		indexHTML, err := f.ReadFile("web/dist/index.html")
		if err != nil {
			// Answer with a 500 instead of an empty 200 when the page is
			// missing from the embedded assets.
			_ = c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		// Set the content type explicitly rather than leaving it to be
		// guessed by the client.
		c.Data(http.StatusOK, "text/html; charset=utf-8", indexHTML)
	}

	if password != "" {
		accountList := map[string]string{
			username: password,
		}
		router.Group("/", gin.BasicAuth(accountList)).GET("", serveIndex)
	} else {
		router.GET("/", serveIndex)
	}

	staticFs, err := fs.Sub(f, "web/dist/static")
	if err != nil {
		panic(fmt.Sprintf("webssh: locating embedded static assets: %v", err))
	}
	router.StaticFS("/static", http.FS(staticFs))
}
// main builds the gin engine, installs gzip compression, registers the
// static front-end routes and the API routes, and then serves on the port
// given by -p. If the server cannot start or stops with an error, the error
// is written to stderr and the process exits with status 1.
//
// NOTE(review): staticRouter only places "/" behind basic authentication.
// The routes registered below are added directly on the engine and therefore
// do not pass through that middleware — confirm whether they should.
func main() {
	server := gin.Default()
	server.Use(gzip.Gzip(gzip.DefaultCompression))
	staticRouter(server)

	// The -t flag is expressed in minutes; convert it to a Duration here.
	server.GET("/term", func(c *gin.Context) {
		controller.TermWs(c, time.Duration(timeout)*time.Minute)
	})

	// The check response always carries the savePass setting in its Data
	// field, replacing whatever CheckSSH put there.
	server.GET("/check", func(c *gin.Context) {
		responseBody := controller.CheckSSH(c)
		responseBody.Data = map[string]interface{}{
			"savePass": savePass,
		}
		c.JSON(http.StatusOK, responseBody)
	})

	file := server.Group("/file")
	{
		file.GET("/list", func(c *gin.Context) {
			c.JSON(http.StatusOK, controller.FileList(c))
		})
		file.GET("/download", func(c *gin.Context) {
			controller.DownloadFile(c)
		})
		file.POST("/upload", func(c *gin.Context) {
			c.JSON(http.StatusOK, controller.UploadFile(c))
		})
	}

	// Run blocks until the server fails; surface that failure instead of
	// exiting silently with status 0.
	if err := server.Run(fmt.Sprintf(":%d", *port)); err != nil {
		fmt.Fprintf(os.Stderr, "webssh: server stopped: %v\n", err)
		os.Exit(1)
	}
}