Commit

Add testcases

hupe1980 committed Dec 17, 2023
1 parent 91dbfe2 commit e28935e

Showing 5 changed files with 287 additions and 17 deletions.
5 changes: 3 additions & 2 deletions model/chatmodel/anthropic.go
@@ -67,8 +67,9 @@ func NewAnthropicFromClient(client AnthropicClient, optFns ...func(o *AnthropicO
 		CallbackOptions: &schema.CallbackOptions{
 			Verbose: golc.Verbose,
 		},
-		ModelName: "claude-v1",
-		MaxTokens: 256,
+		ModelName:   "claude-v1",
+		Temperature: 0.5,
+		MaxTokens:   256,
 	}
 
 	for _, fn := range optFns {
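The added Temperature default is set before the functional-options loop runs, so callers can still override it per instance. A minimal usage sketch, assuming only the NewAnthropicFromClient signature and AnthropicOptions fields visible in this hunk (client is a placeholder value):

	// Hypothetical caller; overrides the new 0.5 default.
	model, err := NewAnthropicFromClient(client, func(o *AnthropicOptions) {
		o.Temperature = 0.2
	})
	if err != nil {
		// handle error
	}
	_ = model
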
17 changes: 17 additions & 0 deletions model/chatmodel/anthropic_test.go
@@ -70,6 +70,23 @@ func TestAnthropic(t *testing.T) {
 			assert.Nil(t, result, "Expected nil result")
 		})
 	})
+
+	t.Run("Type", func(t *testing.T) {
+		assert.Equal(t, "chatmodel.Anthropic", anthropicModel.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		assert.Equal(t, anthropicModel.opts.CallbackOptions.Callbacks, anthropicModel.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		// Call the InvocationParams method
+		params := anthropicModel.InvocationParams()
+
+		// Assert the result
+		assert.Equal(t, "claude-v1", params["model_name"])
+		assert.Equal(t, float32(0.5), params["temperature"])
+	})
 }
 
 func TestConvertMessagesToAnthropicPrompt(t *testing.T) {
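The InvocationParams subtest pins both the map keys and the configured defaults. A plausible sketch of the method, inferred only from the keys this test asserts; the real implementation may construct the map differently (e.g. via struct tags):

	// Hypothetical shape of chatmodel.Anthropic.InvocationParams.
	func (cm *Anthropic) InvocationParams() map[string]any {
		return map[string]any{
			"model_name":  cm.opts.ModelName,
			"temperature": cm.opts.Temperature,
			"max_tokens":  cm.opts.MaxTokens,
		}
	}
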
128 changes: 128 additions & 0 deletions model/chatmodel/ollama_test.go
@@ -0,0 +1,128 @@
package chatmodel

import (
	"context"
	"errors"
	"testing"

	"github.com/hupe1980/golc/schema"
	"github.com/stretchr/testify/assert"

	"github.com/hupe1980/golc/integration/ollama"
)

func TestOllama(t *testing.T) {
	t.Parallel()

	t.Run("Generate", func(t *testing.T) {
		t.Parallel()

		t.Run("Success", func(t *testing.T) {
			t.Parallel()

			mockClient := &mockOllamaClient{
				GenerateChatFunc: func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
					assert.Equal(t, "llama2", req.Model)
					assert.Len(t, req.Messages, 2)
					assert.Equal(t, "user", req.Messages[0].Role)
					assert.Equal(t, "Hello", req.Messages[0].Content)
					assert.Equal(t, "assistant", req.Messages[1].Role)
					assert.Equal(t, "How can I help you?", req.Messages[1].Content)

					return &ollama.ChatResponse{
						Message: &ollama.Message{
							Role:    "assistant",
							Content: "I can help you with that.",
						},
					}, nil
				},
			}

			ollamaModel, err := NewOllama(mockClient)
			assert.NoError(t, err)

			// Simulate chat messages
			messages := []schema.ChatMessage{
				schema.NewHumanChatMessage("Hello"),
				schema.NewAIChatMessage("How can I help you?"),
			}

			// Run the model
			result, err := ollamaModel.Generate(context.Background(), messages)
			assert.NoError(t, err)

			// Check the result
			assert.Len(t, result.Generations, 1)
			assert.Equal(t, "I can help you with that.", result.Generations[0].Text)
		})

		t.Run("Error", func(t *testing.T) {
			t.Parallel()

			mockClient := &mockOllamaClient{
				GenerateChatFunc: func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
					return nil, errors.New("error generating chat")
				},
			}

			ollamaModel, err := NewOllama(mockClient)
			assert.NoError(t, err)

			messages := []schema.ChatMessage{
				schema.NewHumanChatMessage("Hello"),
				schema.NewAIChatMessage("How can I help you?"),
			}

			result, err := ollamaModel.Generate(context.Background(), messages)
			assert.Error(t, err)
			assert.Nil(t, result)
		})
	})

	t.Run("Type", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		assert.Equal(t, "chatmodel.Ollama", ollamaModel.Type())
	})

	t.Run("Callbacks", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		assert.Equal(t, ollamaModel.opts.CallbackOptions.Callbacks, ollamaModel.Callbacks())
	})

	t.Run("InvocationParams", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		params := ollamaModel.InvocationParams()
		assert.NotNil(t, params)
		assert.Equal(t, float32(0.7), params["temperature"])
		assert.Equal(t, 256, params["max_tokens"])
	})
}

// mockOllamaClient is a mock implementation of the chatmodel.OllamaClient interface.
type mockOllamaClient struct {
	GenerateChatFunc func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error)
}

// GenerateChat is the mock implementation of the GenerateChat method for mockOllamaClient.
func (m *mockOllamaClient) GenerateChat(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
	if m.GenerateChatFunc != nil {
		return m.GenerateChatFunc(ctx, req)
	}

	return nil, errors.New("GenerateChatFunc not implemented")
}
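
The mock's single method pins down the surface the chat model needs from its client. Inferred from the mock alone, the chatmodel.OllamaClient interface is presumably close to the following (the real declaration may carry more methods):

	// Assumed shape, derived from mockOllamaClient above.
	type OllamaClient interface {
		GenerateChat(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error)
	}
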
45 changes: 30 additions & 15 deletions model/llm/ai21_test.go
@@ -9,23 +9,9 @@ import (
"github.com/stretchr/testify/assert"
)

// MockAI21Client is a custom mock implementation of the AI21Client interface.
type MockAI21Client struct {
CreateCompletionFunc func(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error)
}

// CreateCompletion mocks the CreateCompletion method of AI21Client.
func (m *MockAI21Client) CreateCompletion(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error) {
if m.CreateCompletionFunc != nil {
return m.CreateCompletionFunc(ctx, model, req)
}

return nil, errors.New("CreateCompletionFunc not implemented")
}

func TestAI21(t *testing.T) {
// Initialize the AI21 client with a mock client
mockClient := &MockAI21Client{}
mockClient := &mockAI21Client{}
llm, err := NewAI21FromClient(mockClient, func(o *AI21Options) {
o.Model = "j2-mid"
o.Temperature = 0.7
@@ -80,4 +66,33 @@ func TestAI21(t *testing.T) {
 		assert.Nil(t, result)
 		assert.Equal(t, expectedError, err)
 	})
+
+	t.Run("Type", func(t *testing.T) {
+		assert.Equal(t, "llm.AI21", llm.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		assert.Equal(t, llm.opts.CallbackOptions.Callbacks, llm.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		params := llm.InvocationParams()
+		assert.NotNil(t, params)
+		assert.Equal(t, 0.7, params["temperature"])
+		assert.Equal(t, "j2-mid", params["model"])
+	})
 }
+
+// mockAI21Client is a custom mock implementation of the AI21Client interface.
+type mockAI21Client struct {
+	CreateCompletionFunc func(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error)
+}
+
+// CreateCompletion mocks the CreateCompletion method of AI21Client.
+func (m *mockAI21Client) CreateCompletion(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error) {
+	if m.CreateCompletionFunc != nil {
+		return m.CreateCompletionFunc(ctx, model, req)
+	}
+
+	return nil, errors.New("CreateCompletionFunc not implemented")
+}
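
Because the mock exposes its behavior as a plain function field, each test injects exactly the response it needs without a mocking framework. A minimal sketch of configuring it (the empty CompleteResponse literal is a placeholder; the real fields live in the ai21 package):

	mockClient := &mockAI21Client{
		CreateCompletionFunc: func(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error) {
			// Return a canned completion for this test case.
			return &ai21.CompleteResponse{}, nil
		},
	}
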
109 changes: 109 additions & 0 deletions model/llm/ollama_test.go
@@ -0,0 +1,109 @@
package llm

import (
	"context"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/hupe1980/golc/integration/ollama"
)

func TestOllama(t *testing.T) {
	t.Parallel()

	t.Run("Generate", func(t *testing.T) {
		t.Parallel()

		t.Run("Success", func(t *testing.T) {
			t.Parallel()

			mockClient := &mockOllamaClient{
				GenerateFunc: func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
					assert.Equal(t, "llama2", req.Model)
					assert.Equal(t, "Hello", req.Prompt)

					return &ollama.GenerateResponse{
						Response: "I can help you with that.",
					}, nil
				},
			}

			ollamaModel, err := NewOllama(mockClient)
			assert.NoError(t, err)

			// Run the model
			result, err := ollamaModel.Generate(context.Background(), "Hello")
			assert.NoError(t, err)

			// Check the result
			assert.Len(t, result.Generations, 1)
			assert.Equal(t, "I can help you with that.", result.Generations[0].Text)
		})

		t.Run("Error", func(t *testing.T) {
			t.Parallel()

			mockClient := &mockOllamaClient{
				GenerateFunc: func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
					return nil, errors.New("error generating chat")
				},
			}

			ollamaModel, err := NewOllama(mockClient)
			assert.NoError(t, err)

			result, err := ollamaModel.Generate(context.Background(), "Hello")
			assert.Error(t, err)
			assert.Nil(t, result)
		})
	})

	t.Run("Type", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		assert.Equal(t, "llm.Ollama", ollamaModel.Type())
	})

	t.Run("Callbacks", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		assert.Equal(t, ollamaModel.opts.CallbackOptions.Callbacks, ollamaModel.Callbacks())
	})

	t.Run("InvocationParams", func(t *testing.T) {
		t.Parallel()

		mockClient := &mockOllamaClient{}
		ollamaModel, err := NewOllama(mockClient)
		assert.NoError(t, err)

		params := ollamaModel.InvocationParams()
		assert.NotNil(t, params)
		assert.Equal(t, float32(0.7), params["temperature"])
		assert.Equal(t, 256, params["max_tokens"])
	})
}

// mockOllamaClient is a mock implementation of the llm.OllamaClient interface.
type mockOllamaClient struct {
	GenerateFunc func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error)
}

// Generate is the mock implementation of the Generate method for mockOllamaClient.
func (m *mockOllamaClient) Generate(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
	if m.GenerateFunc != nil {
		return m.GenerateFunc(ctx, req)
	}

	return nil, errors.New("GenerateFunc not implemented")
}
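
Note the asymmetry with the chatmodel test above: the llm-side client exposes Generate over a plain prompt string rather than GenerateChat over a message list. Inferred from the mock, the llm.OllamaClient interface is presumably close to:

	// Assumed shape, derived from mockOllamaClient above.
	type OllamaClient interface {
		Generate(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error)
	}
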
