Skip to content

Commit

Permalink
feat: 支持core的系统级别gpt
Browse files Browse the repository at this point in the history
  • Loading branch information
feng626 committed Dec 1, 2023
1 parent 99691d2 commit b678efb
Show file tree
Hide file tree
Showing 15 changed files with 466 additions and 606 deletions.
44 changes: 25 additions & 19 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,29 @@ go 1.20
require (
github.com/dlclark/regexp2 v1.10.0
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/jumpserver/wisp v0.1.15
github.com/google/uuid v1.4.0
github.com/gorilla/websocket v1.5.1
github.com/jumpserver/wisp v0.1.16
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible
github.com/sashabaranov/go-openai v1.14.1
github.com/spf13/viper v1.16.0
github.com/sashabaranov/go-openai v1.17.9
github.com/spf13/viper v1.17.0
go.uber.org/zap v1.24.0
google.golang.org/grpc v1.57.0
google.golang.org/grpc v1.59.0
)

require (
github.com/bytedance/sonic v1.10.0-rc2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/chenzhuoyu/iasm v0.9.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
Expand All @@ -38,25 +39,30 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.4.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230814215434-ca7cfce7776a // indirect
golang.org/x/crypto v0.16.0 // indirect
golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/jumpserver/wisp => /Users/xiaofeng/Desktop/wisp/
517 changes: 47 additions & 470 deletions go.sum

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions pkg/httpd/middlewares/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package middlewares

import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/jumpserver/kael/pkg/jms"
"github.com/jumpserver/kael/pkg/logger"
"go.uber.org/zap"
"net/http"
"net/url"
)

func AuthMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
reqCookies := ctx.Request.Cookies()
checkUserHandler := jms.NewCheckUserHandler()
user, err := checkUserHandler.CheckUserByCookies(reqCookies)
if err != nil {
logger.GlobalLogger.Error("Check user cookie failed", zap.Error(err))
loginUrl := fmt.Sprintf("/core/auth/login/?next=%s", url.QueryEscape(ctx.Request.URL.RequestURI()))
ctx.Redirect(http.StatusFound, loginUrl)
ctx.Abort()
return
}
ctx.Set("CONTEXT_USER", user)
}
}
200 changes: 114 additions & 86 deletions pkg/httpd/router/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ import (
"github.com/gorilla/websocket"
"github.com/jumpserver/kael/pkg/httpd/ws"
"github.com/jumpserver/kael/pkg/jms"
"github.com/jumpserver/kael/pkg/logger"
"github.com/jumpserver/kael/pkg/manager"
"github.com/jumpserver/kael/pkg/schemas"
"github.com/jumpserver/wisp/protobuf-go/protobuf"
"github.com/sashabaranov/go-openai"
"go.uber.org/zap"
"net/http"
"time"
)
Expand All @@ -24,13 +21,15 @@ type _ChatApi struct{}
func (s *_ChatApi) ChatHandler(ctx *gin.Context) {
conn, err := ws.UpgradeWsConn(ctx)
if err != nil {
logger.GlobalLogger.Error("Websocket upgrade err", zap.Error(err))
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "websocket upgrade failed"})
return
}

defer conn.Close()

token, ok := ctx.GetQuery("token")
if !ok {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "token"})
ctx.JSON(http.StatusBadRequest, gin.H{"error": "token required"})
return
}

Expand All @@ -39,115 +38,144 @@ func (s *_ChatApi) ChatHandler(ctx *gin.Context) {
sessionHandler := jms.NewSessionHandler(conn)
authInfo, err := tokenHandler.GetTokenAuthInfo(token)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "auth fail"})
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

defer conn.Close()

for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.GlobalLogger.Info("Accept message error or connect closed")
if len(currentJMSS) != 0 {
for _, jmss := range currentJMSS {
reason := "Websocket已关闭, 会话中断"
jmss.Close(reason)
}
messageType, msg, err := conn.ReadMessage()
if err != nil && len(currentJMSS) != 0 {
for _, jmss := range currentJMSS {
jmss.Close("Websocket已关闭, 会话中断")
}
return
}

if string(msg) == "ping" {
_ = conn.WriteMessage(websocket.TextMessage, []byte("pong"))
continue
}
if messageType == websocket.TextMessage {
if string(msg) == "ping" {
_ = conn.WriteMessage(websocket.TextMessage, []byte("pong"))
return
}

var askRequest schemas.AskRequest
_ = json.Unmarshal(msg, &askRequest)

var askRequest schemas.AskRequest
_ = json.Unmarshal(msg, &askRequest)
jmss := &jms.JMSSession{}
if askRequest.ConversationID == "" {
jmss = sessionHandler.CreateNewSession(authInfo)
jmss.ActiveSession()
currentJMSS = append(currentJMSS, jmss)
} else {
jmss := &jms.JMSSession{}
conversationID := askRequest.ConversationID
jmss = jms.GlobalSessionManager.GetJMSSession(conversationID)
if jmss == nil {
response := schemas.AskResponse{
Type: schemas.Error,
ConversationID: askRequest.ConversationID,
SystemMessage: "current session not found",
}
jsonResponse, _ := json.Marshal(response)
_ = conn.WriteMessage(websocket.TextMessage, jsonResponse)
continue
if conversationID == "" {
jmss = sessionHandler.CreateNewSession(authInfo)
jmss.ActiveSession()
currentJMSS = append(currentJMSS, jmss)
} else {
jmss.JMSState.NewDialogue = true
jmss, err = jms.GlobalSessionManager.GetSession(conversationID)
if err != nil {
sendErrorMessage(conn, "current session not found", conversationID)
return
} else {
jmss.JMSState.NewDialogue = true
}
}

chatGPTParam := &manager.ChatGPTParam{
AuthToken: authInfo.Account.Secret,
BaseURL: authInfo.Asset.Address,
Proxy: authInfo.Asset.Specific.HttpProxy,
Model: authInfo.Platform.Protocols[0].Settings["api_mode"],
}

id := jmss.GetID()
wsConn := jmss.Websocket
currentAskInterrupt := &jmss.CurrentAskInterrupt
jmss.HistoryAsks = append(jmss.HistoryAsks, askRequest.Content)
go jmss.WithAudit(
askRequest.Content,
chatFunc(
chatGPTParam, jmss.HistoryAsks,
id, wsConn, currentAskInterrupt,
),
)
}
go jmss.WithAudit(askRequest.Content, chatFunc(authInfo, askRequest))
}
}

func chatFunc(authInfo *protobuf.TokenAuthInfo, askRequest schemas.AskRequest) func(jmss *jms.JMSSession) string {
return func(jmss *jms.JMSSession) string {
func chatFunc(
chatGPTParam *manager.ChatGPTParam, historyAsks []string,
id string, wsConn *websocket.Conn, currentAskInterrupt *bool,
) func() string {
return func() string {
doneCh := make(chan string)
answerCh := make(chan string)

model := authInfo.Platform.Protocols[0].Settings["api_mode"]
jmss.HistoryAsks = append(jmss.HistoryAsks, askRequest.Content)
defer close(doneCh)
defer close(answerCh)

c := manager.NewClient(
authInfo.Account.Secret,
authInfo.Asset.Address,
authInfo.Asset.Specific.HttpProxy,
chatGPTParam.AuthToken,
chatGPTParam.BaseURL,
chatGPTParam.Proxy,
)

askChatGPT := &manager.AskChatGPT{
Client: c,
Model: model,
Contents: jmss.HistoryAsks,
Model: chatGPTParam.Model,
Contents: historyAsks,
AnswerCh: answerCh,
DoneCh: doneCh,
}

go manager.ChatGPT(askChatGPT, jmss)
messageID := uuid.New()
for {
select {
case answer := <-answerCh:
response := schemas.AskResponse{
Type: schemas.Message,
ConversationID: jmss.Session.Id,
Message: &schemas.ChatGPTMessage{
Content: answer,
ID: messageID,
CreateTime: time.Now(),
Type: schemas.Message,
Role: openai.ChatMessageRoleAssistant,
},
}
jsonResponse, _ := json.Marshal(response)
_ = jmss.Websocket.WriteMessage(websocket.TextMessage, jsonResponse)
case answer := <-doneCh:
response := schemas.AskResponse{
Type: schemas.Message,
ConversationID: jmss.Session.Id,
Message: &schemas.ChatGPTMessage{
Content: answer,
ID: messageID,
CreateTime: time.Now(),
Type: schemas.Finish,
Role: openai.ChatMessageRoleAssistant,
},
}
jsonResponse, _ := json.Marshal(response)
_ = jmss.Websocket.WriteMessage(websocket.TextMessage, jsonResponse)
close(doneCh)
close(answerCh)
return answer
}
go manager.ChatGPT(askChatGPT)
return processChatMessages(currentAskInterrupt, id, answerCh, doneCh, wsConn)
}
}

func sendErrorMessage(conn *websocket.Conn, message, conversationID string) {
response := schemas.AskResponse{
Type: schemas.Error,
ConversationID: conversationID,
SystemMessage: message,
}
jsonResponse, _ := json.Marshal(response)
_ = conn.WriteMessage(websocket.TextMessage, jsonResponse)
}

func processChatMessages(
currentAskInterrupt *bool, id string,
answerCh <-chan string, doneCh <-chan string, wsConn *websocket.Conn,
) string {
content := ""
messageID := uuid.New()

for {
select {
case answer := <-answerCh:
content = answer
sendChatResponse(id, wsConn, answer, messageID, schemas.Message)
case answer := <-doneCh:
content = answer
sendChatResponse(id, wsConn, answer, messageID, schemas.Finish)
return answer
}

if *currentAskInterrupt {
*currentAskInterrupt = false
return content
}
}
}

func sendChatResponse(
id string, ws *websocket.Conn, chatContent string,
messageID uuid.UUID, messageType schemas.AskResponseType) {
response := schemas.AskResponse{
Type: schemas.Message,
ConversationID: id,
Message: &schemas.ChatGPTMessage{
Content: chatContent,
ID: messageID,
CreateTime: time.Now(),
Type: messageType,
Role: openai.ChatMessageRoleAssistant,
},
}
jsonResponse, _ := json.Marshal(response)
_ = ws.WriteMessage(websocket.TextMessage, jsonResponse)
}
8 changes: 4 additions & 4 deletions pkg/httpd/router/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ var HandlerApi = new(_HandlerApi)

type _HandlerApi struct{}

func getJmsSession(sessionID string) (*jms.JMSSession, error) {
jmsSession := jms.GlobalSessionManager.GetJMSSession(sessionID)
func GetSession(sessionID string) (*jms.JMSSession, error) {
jmsSession, _ := jms.GlobalSessionManager.GetSession(sessionID)
if jmsSession != nil {
return jmsSession, nil
}
Expand All @@ -27,7 +27,7 @@ func (s *_HandlerApi) InterruptCurrentAskHandler(ctx *gin.Context) {
return
}

jmsSession, err := getJmsSession(conversation.ID)
jmsSession, err := GetSession(conversation.ID)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"message": err.Error()})
return
Expand All @@ -42,7 +42,7 @@ func (s *_HandlerApi) JmsStateHandler(ctx *gin.Context) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid data"})
return
}
jmsSession, err := getJmsSession(jmsState.ID)
jmsSession, err := GetSession(jmsState.ID)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"message": err.Error()})
return
Expand Down
Loading

0 comments on commit b678efb

Please sign in to comment.