Skip to content

Commit

Permalink
feat: 添加prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
feng626 committed Dec 13, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent afde383 commit 60b16fb
Showing 7 changed files with 23 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pkg/httpd/router/chat.go
Original file line number Diff line number Diff line change
@@ -65,7 +65,7 @@ func (s *_ChatApi) ChatHandler(ctx *gin.Context) {
jmss := &jms.JMSSession{}
conversationID := askRequest.ConversationID
if conversationID == "" {
jmss = sessionHandler.CreateNewSession(authInfo)
jmss = sessionHandler.CreateNewSession(authInfo, askRequest.Prompt)
jmss.ActiveSession()
currentJMSS = append(currentJMSS, jmss)
} else {
@@ -83,6 +83,7 @@ func (s *_ChatApi) ChatHandler(ctx *gin.Context) {
BaseURL: authInfo.Asset.Address,
Proxy: authInfo.Asset.Specific.HttpProxy,
Model: authInfo.Platform.Protocols[0].Settings["api_mode"],
Prompt: jmss.Prompt,
}

id := jmss.GetID()
@@ -124,7 +125,7 @@ func chatFunc(
DoneCh: doneCh,
}

go manager.ChatGPT(askChatGPT)
go manager.ChatGPT(askChatGPT, chatGPTParam.Prompt)
return processChatMessages(currentAskInterrupt, id, answerCh, doneCh, wsConn)
}
}
3 changes: 2 additions & 1 deletion pkg/httpd/router/system_chat.go
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ func (s *_SystemChatApi) ChatHandler(ctx *gin.Context) {
jmss := &jms.JMSSystemSession{}
conversationID := askRequest.ConversationID
if conversationID == "" {
jmss = jmss.CreateNewSession(conn)
jmss = jmss.CreateNewSession(conn, askRequest.Prompt)
jmss.ActiveSession()
currentJMSS = append(currentJMSS, jmss)
} else {
@@ -71,6 +71,7 @@ func (s *_SystemChatApi) ChatHandler(ctx *gin.Context) {
BaseURL: publicSetting.GptBaseUrl,
Proxy: publicSetting.GptProxy,
Model: publicSetting.GptModel,
Prompt: jmss.Prompt,
}
id := jmss.GetID()
wsConn := jmss.Websocket
5 changes: 4 additions & 1 deletion pkg/jms/jms_session.go
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ import (
type JMSSession struct {
Session *protobuf.Session
Websocket *websocket.Conn
Prompt string
HistoryAsks []string
CurrentAskInterrupt bool
CommandACLs []*protobuf.CommandACL
@@ -113,6 +114,7 @@ func (jmss *JMSSession) WithAudit(command string, chatFunc func() string) (resul

type JMSSystemSession struct {
Id string
Prompt string
HistoryAsks []string
CurrentAskInterrupt bool
Websocket *websocket.Conn
@@ -123,12 +125,13 @@ func (sh *JMSSystemSession) GetID() string {
return sh.Id
}

func (sh *JMSSystemSession) CreateNewSession(conn *websocket.Conn) *JMSSystemSession {
func (sh *JMSSystemSession) CreateNewSession(conn *websocket.Conn, prompt string) *JMSSystemSession {
id := uuid.New().String()
return &JMSSystemSession{
Id: id,
Websocket: conn,
CurrentAskInterrupt: false,
Prompt: prompt,
HistoryAsks: make([]string, 0),
JMSState: &schemas.JMSState{
ID: id, ActivateReview: schemas.Wait,
3 changes: 2 additions & 1 deletion pkg/jms/session.go
Original file line number Diff line number Diff line change
@@ -27,11 +27,12 @@ func getRemoteAddress(websocket *websocket.Conn) string {
return websocket.RemoteAddr().String()
}

func (sh *SessionHandler) CreateNewSession(authInfo *protobuf.TokenAuthInfo) *JMSSession {
func (sh *SessionHandler) CreateNewSession(authInfo *protobuf.TokenAuthInfo, prompt string) *JMSSession {
session := sh.createSession(authInfo)
return &JMSSession{
Session: session,
Websocket: sh.Websocket,
Prompt: prompt,
HistoryAsks: make([]string, 0),
CurrentAskInterrupt: false,
CommandACLs: authInfo.FilterRules,
11 changes: 10 additions & 1 deletion pkg/manager/openai.go
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ func NewClient(authToken, baseURL, proxy string) *openai.Client {
return openai.NewClientWithConfig(config)
}

func ChatGPT(ask *AskChatGPT) {
func ChatGPT(ask *AskChatGPT, prompt string) {
// TODO 做超时处理
ctx := context.Background()
messages := make([]openai.ChatCompletionMessage, 0)
@@ -84,6 +84,15 @@ func ChatGPT(ask *AskChatGPT) {
})
}

if prompt != "" {
messages = append([]openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: prompt,
},
}, messages...)
}

req := openai.ChatCompletionRequest{
Model: ask.Model,
Messages: messages,
1 change: 1 addition & 0 deletions pkg/manager/schemas.go
Original file line number Diff line number Diff line change
@@ -15,4 +15,5 @@ type ChatGPTParam struct {
BaseURL string
Proxy string
Model string
Prompt string
}
1 change: 1 addition & 0 deletions pkg/schemas/conversation.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ package schemas
type AskRequest struct {
ConversationID string `json:"conversation_id,omitempty"`
Content string `json:"content"`
Prompt string `json:"prompt"`
}

type AskResponseType string

0 comments on commit 60b16fb

Please sign in to comment.