Skip to content

Commit

Permalink
Merge branch 'xqdoo00o-master' into Add-Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
bi1101 committed Mar 16, 2024
2 parents 920a5e1 + f0bbcfa commit 6733db3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
22 changes: 19 additions & 3 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"io"
"net/http"
"sync"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
Expand Down Expand Up @@ -141,11 +142,26 @@ func nightmare(c *gin.Context) {
}
uid := uuid.NewString()
var err error
err = chatgpt.InitWSConn(token, uid, proxy_url)
var chat_require *chatgpt.ChatRequire
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
err = chatgpt.InitWSConn(token, uid, proxy_url)
}()
go func() {
defer wg.Done()
chat_require = chatgpt.CheckRequire(token, puid, proxy_url)
}()
wg.Wait()
if err != nil {
c.JSON(500, gin.H{"error": "unable to create ws tunnel"})
return
}
if chat_require == nil {
c.JSON(500, gin.H{"error": "unable to check chat requirement"})
return
}
chat_require := chatgpt.CheckRequire(token, puid, proxy_url)
// Convert the chat request to a ChatGPT request
translated_request := chatgpt_request_converter.ConvertAPIRequest(original_request, puid, chat_require.Arkose.Required, proxy_url)

Expand Down Expand Up @@ -174,7 +190,7 @@ func nightmare(c *gin.Context) {
translated_request.Action = "continue"
translated_request.ConversationID = continue_info.ConversationID
translated_request.ParentMessageID = continue_info.ParentID
if strings.HasPrefix(original_request.Model, "gpt-4") {
if chat_require.Arkose.Required {
chatgpt_request_converter.RenewTokenForRequest(&translated_request, puid, proxy_url)
}
response, err = chatgpt.POSTconversation(translated_request, token, puid, chat_require.Token, proxy_url)
Expand Down
36 changes: 25 additions & 11 deletions internal/chatgpt/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import (
"sync"
"time"

hp "net/http"

"github.com/gorilla/websocket"

http "github.com/bogdanfinn/fhttp"
Expand Down Expand Up @@ -76,14 +74,10 @@ func getWSURL(token string, retry int) (string, error) {
}

func createWSConn(url string, connInfo *connInfo, retry int) error {
header := make(hp.Header)
header.Add("Sec-WebSocket-Protocol", "json.reliable.webpubsub.azure.v1")
dialer := websocket.Dialer{
Proxy: hp.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
EnableCompression: true,
}
conn, _, err := dialer.Dial(url, header)
dialer := websocket.DefaultDialer
dialer.EnableCompression = true
dialer.Subprotocols = []string{"json.reliable.webpubsub.azure.v1"}
conn, _, err := dialer.Dial(url, nil)
if err != nil {
if retry > 3 {
return err
Expand Down Expand Up @@ -234,6 +228,7 @@ func CheckRequire(access_token string, puid string, proxy string) *ChatRequire {
}
return &require
}

func POSTconversation(message chatgpt_types.ChatGPTRequest, access_token string, puid string, chat_token string, proxy string) (*http.Response, error) {
if proxy != "" {
client.SetProxy(proxy)
Expand Down Expand Up @@ -379,6 +374,8 @@ func Handler(c *gin.Context, response *http.Response, token string, puid string,
var wssUrl string
var connInfo *connInfo
var wsSeq int
var isWSInterrupt bool = false
var interruptTimer *time.Timer

if !strings.Contains(response.Header.Get("Content-Type"), "text/event-stream") {
isWSS = true
Expand All @@ -399,9 +396,21 @@ func Handler(c *gin.Context, response *http.Response, token string, puid string,
if isWSS {
var messageType int
var message []byte
if isWSInterrupt {
if interruptTimer == nil {
interruptTimer = time.NewTimer(10 * time.Second)
}
select {
case <-interruptTimer.C:
c.JSON(500, gin.H{"error": "WS interrupt & new WS timeout"})
return "", nil
default:
goto reader
}
}
reader:
messageType, message, err = connInfo.conn.ReadMessage()
if err != nil {
println(err.Error())
connInfo.ticker.Stop()
connInfo.conn.Close()
connInfo.conn = nil
Expand All @@ -410,6 +419,7 @@ func Handler(c *gin.Context, response *http.Response, token string, puid string,
c.JSON(500, gin.H{"error": err.Error()})
return "", nil
}
isWSInterrupt = true
connInfo.conn.WriteMessage(websocket.TextMessage, []byte("{\"type\":\"sequenceAck\",\"sequenceId\":"+strconv.Itoa(wsSeq)+"}"))
continue
}
Expand All @@ -428,6 +438,10 @@ func Handler(c *gin.Context, response *http.Response, token string, puid string,
if err != nil {
continue
}
if isWSInterrupt {
isWSInterrupt = false
interruptTimer.Stop()
}
line = string(bodyByte)
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion typings/chatgpt/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type chatgpt_author struct {

type ChatGPTRequest struct {
Action string `json:"action"`
Messages []chatgpt_message `json:"messages"`
Messages []chatgpt_message `json:"messages,omitempty"`
ParentMessageID string `json:"parent_message_id,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
Model string `json:"model"`
Expand Down

0 comments on commit 6733db3

Please sign in to comment.