Skip to content

Commit

Permalink
code style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
urjitbhatia committed Nov 7, 2023
1 parent 34f8671 commit d95517e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
16 changes: 8 additions & 8 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ const (
)

type Message struct {
Id string `json:"id"`
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int `json:"created_at"`
ThreadId string `json:"thread_id"`
ThreadID string `json:"thread_id"`
Role string `json:"role"`
Content []MessageContent `json:"content"`
FileIds []interface{} `json:"file_ids"`
AssistantId string `json:"assistant_id"`
RunId string `json:"run_id"`
AssistantID string `json:"assistant_id"`
RunID string `json:"run_id"`
Metadata map[string]any `json:"metadata"`

httpHeader
Expand Down Expand Up @@ -49,10 +49,10 @@ type MessageRequest struct {
}

type MessageFile struct {
Id string `json:"id"`
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int `json:"created_at"`
MessageId string `json:"message_id"`
MessageID string `json:"message_id"`

httpHeader
}
Expand Down Expand Up @@ -141,7 +141,7 @@ func (c *Client) ModifyMessage(
return
}

// RetrieveMessageFile fetches a message file
// RetrieveMessageFile fetches a message file.
func (c *Client) RetrieveMessageFile(
ctx context.Context,
threadID, messageID, fileID string,
Expand All @@ -156,7 +156,7 @@ func (c *Client) RetrieveMessageFile(
return
}

// ListMessageFiles fetches all files attached to a message
// ListMessageFiles fetches all files attached to a message.
func (c *Client) ListMessageFiles(
ctx context.Context,
threadID, messageID string,
Expand Down
60 changes: 32 additions & 28 deletions messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"net/http"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestMessages Tests the messages endpoint of the API using the mocked server.
Expand All @@ -26,10 +27,10 @@ func TestMessages(t *testing.T) {
case http.MethodGet:
resBytes, _ := json.Marshal(
openai.MessageFile{
Id: fileID,
ID: fileID,
Object: "thread.message.file",
CreatedAt: 1699061776,
MessageId: messageID,
MessageID: messageID,
})
fmt.Fprintln(w, string(resBytes))
default:
Expand All @@ -45,10 +46,10 @@ func TestMessages(t *testing.T) {
case http.MethodGet:
resBytes, _ := json.Marshal(
openai.MessageFilesList{MessageFiles: []openai.MessageFile{{
Id: fileID,
ID: fileID,
Object: "thread.message.file",
CreatedAt: 0,
MessageId: messageID,
MessageID: messageID,
}}})
fmt.Fprintln(w, string(resBytes))
default:
Expand All @@ -68,10 +69,10 @@ func TestMessages(t *testing.T) {

resBytes, _ := json.Marshal(
openai.Message{
Id: messageID,
ID: messageID,
Object: "thread.message",
CreatedAt: 1234567890,
ThreadId: threadID,
ThreadID: threadID,
Role: "user",
Content: []openai.MessageContent{{
Type: "text",
Expand All @@ -81,18 +82,18 @@ func TestMessages(t *testing.T) {
},
}},
FileIds: nil,
AssistantId: "",
RunId: "",
AssistantID: "",
RunID: "",
Metadata: metadata,
})
fmt.Fprintln(w, string(resBytes))
case http.MethodGet:
resBytes, _ := json.Marshal(
openai.Message{
Id: messageID,
ID: messageID,
Object: "thread.message",
CreatedAt: 1234567890,
ThreadId: threadID,
ThreadID: threadID,
Role: "user",
Content: []openai.MessageContent{{
Type: "text",
Expand All @@ -102,8 +103,8 @@ func TestMessages(t *testing.T) {
},
}},
FileIds: nil,
AssistantId: "",
RunId: "",
AssistantID: "",
RunID: "",
Metadata: nil,
})
fmt.Fprintln(w, string(resBytes))
Expand All @@ -119,10 +120,10 @@ func TestMessages(t *testing.T) {
switch r.Method {
case http.MethodPost:
resBytes, _ := json.Marshal(openai.Message{
Id: messageID,
ID: messageID,
Object: "thread.message",
CreatedAt: 1234567890,
ThreadId: threadID,
ThreadID: threadID,
Role: "user",
Content: []openai.MessageContent{{
Type: "text",
Expand All @@ -132,18 +133,18 @@ func TestMessages(t *testing.T) {
},
}},
FileIds: nil,
AssistantId: "",
RunId: "",
AssistantID: "",
RunID: "",
Metadata: nil,
})
fmt.Fprintln(w, string(resBytes))
case http.MethodGet:
resBytes, _ := json.Marshal(openai.MessagesList{
Messages: []openai.Message{{
Id: messageID,
ID: messageID,
Object: "thread.message",
CreatedAt: 1234567890,
ThreadId: threadID,
ThreadID: threadID,
Role: "user",
Content: []openai.MessageContent{{
Type: "text",
Expand All @@ -153,8 +154,8 @@ func TestMessages(t *testing.T) {
},
}},
FileIds: nil,
AssistantId: "",
RunId: "",
AssistantID: "",
RunID: "",
Metadata: nil,
}}})
fmt.Fprintln(w, string(resBytes))
Expand All @@ -175,6 +176,9 @@ func TestMessages(t *testing.T) {
Metadata: nil,
})
checks.NoError(t, err, "CreateMessage error")
if msg.ID != messageID {
t.Fatalf("unexpected message id: '%s'", msg.ID)
}

var msgs openai.MessagesList
msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil)
Expand All @@ -185,8 +189,8 @@ func TestMessages(t *testing.T) {

msg, err = client.RetrieveMessage(ctx, threadID, messageID)
checks.NoError(t, err, "RetrieveMessage error")
if msg.Id != messageID {
t.Fatalf("unexpected message id: '%s'", msg.Id)
if msg.ID != messageID {
t.Fatalf("unexpected message id: '%s'", msg.ID)
}

msg, err = client.ModifyMessage(ctx, threadID, messageID,
Expand All @@ -202,8 +206,8 @@ func TestMessages(t *testing.T) {
var msgFile openai.MessageFile
msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID)
checks.NoError(t, err, "RetrieveMessageFile error")
if msgFile.Id != fileID {
t.Fatalf("unexpected message file id: '%s'", msgFile.Id)
if msgFile.ID != fileID {
t.Fatalf("unexpected message file id: '%s'", msgFile.ID)
}

var msgFiles openai.MessageFilesList
Expand All @@ -212,7 +216,7 @@ func TestMessages(t *testing.T) {
if len(msgFiles.MessageFiles) != 1 {
t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles))
}
if msgFiles.MessageFiles[0].Id != fileID {
t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].Id)
if msgFiles.MessageFiles[0].ID != fileID {
t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID)
}
}

0 comments on commit d95517e

Please sign in to comment.