Skip to content

Commit

Permalink
#29: Refactor OpenAI client and related functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 21, 2023
1 parent c7d341c commit d922ee9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 87 deletions.
86 changes: 8 additions & 78 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"net/http"
"reflect"
"strings"

"github.com/cloudwego/hertz/pkg/app/client"
)

const (
Expand Down Expand Up @@ -75,7 +73,7 @@ type ChatUsage struct {
// Returns:
// - *ChatResponse: a pointer to a ChatResponse
// - error: An error if the request failed.
func (c *Client) Chat() (*ChatResponse, error) {
func (c *ProviderClient) Chat() (*ChatResponse, error) {
// Create a new chat request

slog.Info("creating chat request")
Expand All @@ -93,7 +91,7 @@ func (c *Client) Chat() (*ChatResponse, error) {
return resp, err
}

func (c *Client) CreateChatRequest(message []byte) *ChatRequest {
func (c *ProviderClient) CreateChatRequest(message []byte) *ChatRequest {
err := json.Unmarshal(message, &requestBody)
if err != nil {
slog.Error("Error:", err)
Expand Down Expand Up @@ -168,10 +166,10 @@ type ChatResponse struct {
}

// CreateChatResponse creates chat Response.
func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) {
func (c *ProviderClient) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatResponse, error) {
_ = ctx // keep this for future use

resp, err := c.createChatHTTP(r)
resp, err := c.createChatHTTP(r) // netpoll -> hertz does not yet support tls
if err != nil {
return nil, err
}
Expand All @@ -181,61 +179,7 @@ func (c *Client) CreateChatResponse(ctx context.Context, r *ChatRequest) (*ChatR
return resp, nil
}

/* will remove later
func (c *Client) createChatHertz(ctx context.Context, payload *ChatRequest) (*ChatResponse, error) {
slog.Info("running createChat")
if payload.StreamingFunc != nil {
payload.Stream = true
}
// Build request payload
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
// Build request
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}
req := &protocol.Request{}
res := &protocol.Response{}
req.Header.SetMethod(consts.MethodPost)
req.SetRequestURI(c.buildURL("/chat/completions", c.Provider.Model))
req.SetBody(payloadBytes)
req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey)
req.Header.Set("Content-Type", "application/json")
slog.Info("making request")
// Send request
err = c.httpClient.Do(ctx, req, res) //*client.Client
if err != nil {
slog.Error(err.Error())
fmt.Println(res.Body())
return nil, err
}
slog.Info("request returned")
defer res.ConnectionClose() // replaced r.Body.Close()
slog.Info(fmt.Sprintf("%d", res.StatusCode()))
if res.StatusCode() != http.StatusOK {
msg := fmt.Sprintf("API returned unexpected status code: %d", res.StatusCode())
return nil, fmt.Errorf("%s: %s", msg, err.Error()) // nolint:goerr113
}
// Parse response
var response ChatResponse
return &response, json.NewDecoder(bytes.NewReader(res.Body())).Decode(&response)
}
*/

func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) {
func (c *ProviderClient) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) {
slog.Info("running createChatHttp")

if payload.StreamingFunc != nil {
Expand Down Expand Up @@ -264,8 +208,7 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) {
req.Header.Set("Authorization", "Bearer "+c.Provider.APIKey)
req.Header.Set("Content-Type", "application/json")

httpClient := &http.Client{}
resp, err := httpClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
slog.Error(err.Error())
return nil, err
Expand All @@ -289,30 +232,17 @@ func (c *Client) createChatHTTP(payload *ChatRequest) (*ChatResponse, error) {
return &response, json.NewDecoder(resp.Body).Decode(&response)
}

func (c *Client) buildURL(suffix string) string {
func (c *ProviderClient) buildURL(suffix string) string {
slog.Info("request url: " + fmt.Sprintf("%s%s", c.baseURL, suffix))

// open ai implement:
return fmt.Sprintf("%s%s", c.baseURL, suffix)
}

func (c *Client) setModel() string {
func (c *ProviderClient) setModel() string {
if c.Provider.Model == "" {
return defaultChatModel
}

return c.Provider.Model
}

// HTTPClient returns a new Hertz HTTP client.
//
// It creates a new client using the client.NewClient() function and returns the client.
// If an error occurs during the creation of the client, it logs the error using slog.Error().
// The function returns the created client or nil if an error occurred.
func HTTPClient() *client.Client {
c, err := client.NewClient()
if err != nil {
slog.Error(err.Error())
}
return c
}
2 changes: 1 addition & 1 deletion pkg/providers/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestOpenAIClient(t *testing.T) {

payloadBytes, _ := json.Marshal(payload)

c, err := OpenAiClient(poolName, modelName, payloadBytes)
c, err := Client(poolName, modelName, payloadBytes)
if err != nil {
slog.Error(err.Error())
return
Expand Down
23 changes: 16 additions & 7 deletions pkg/providers/openai/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"time"

"gopkg.in/yaml.v2"

"glide/pkg/providers"

"github.com/cloudwego/hertz/pkg/app/client"
"github.com/go-playground/validator/v10"
)

Expand All @@ -35,13 +36,21 @@ var (
}
)

var httpClient = &http.Client{
Timeout: time.Second * 60,
Transport: &http.Transport{
MaxIdleConns: 90,
MaxIdleConnsPerHost: 5,
},
}

// Client is a client for the OpenAI API.
type Client struct {
type ProviderClient struct {
Provider providers.Provider `validate:"required"`
PoolName string `validate:"required"`
baseURL string `validate:"required"`
payload []byte `validate:"required"`
httpClient *client.Client `validate:"required"`
httpClient *http.Client `validate:"required"`
}

// OpenAiClient creates a new client for the OpenAI API.
Expand All @@ -53,9 +62,9 @@ type Client struct {
// Returns:
// - *Client: A pointer to the created client.
// - error: An error if the client creation failed.
func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, error) {
func Client(poolName string, modelName string, payload []byte) (*ProviderClient, error) {
// Read the YAML file
data, err := os.ReadFile("config.yaml") // TODO: How will this be accessed? Does it have to be read each time?
data, err := os.ReadFile("/Users/max/code/Glide/config.yaml") // TODO: How will this be accessed? Does it have to be read each time?
if err != nil {
return nil, fmt.Errorf("failed to read YAML file: %w", err)
}
Expand All @@ -81,12 +90,12 @@ func OpenAiClient(poolName string, modelName string, payload []byte) (*Client, e
}

// Create a new client
c := &Client{
c := &ProviderClient{
Provider: *selectedProvider,
PoolName: poolName,
baseURL: defaultBaseURL, // Set the appropriate base URL
payload: payload,
httpClient: HTTPClient(),
httpClient: httpClient,
}

v := validator.New()
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type Pool struct {
}

type Provider struct {
Name string `yaml:"name" validate:"required"`
Name string `yaml:"name" validate:"required"`
Model string `yaml:"model"`
APIKey string `yaml:"api_key" validate:"required"`
TimeoutMs int `yaml:"timeout_ms,omitempty"`
Expand Down

0 comments on commit d922ee9

Please sign in to comment.