diff --git a/model/chatmodel/anthropic.go b/model/chatmodel/anthropic.go
index 47907c0..ce1b7ec 100644
--- a/model/chatmodel/anthropic.go
+++ b/model/chatmodel/anthropic.go
@@ -67,8 +67,9 @@ func NewAnthropicFromClient(client AnthropicClient, optFns ...func(o *AnthropicO
 		CallbackOptions: &schema.CallbackOptions{
 			Verbose: golc.Verbose,
 		},
-		ModelName: "claude-v1",
-		MaxTokens: 256,
+		ModelName:   "claude-v1",
+		Temperature: 0.5,
+		MaxTokens:   256,
 	}
 
 	for _, fn := range optFns {
diff --git a/model/chatmodel/anthropic_test.go b/model/chatmodel/anthropic_test.go
index 45248a9..57f8acd 100644
--- a/model/chatmodel/anthropic_test.go
+++ b/model/chatmodel/anthropic_test.go
@@ -70,6 +70,23 @@ func TestAnthropic(t *testing.T) {
 			assert.Nil(t, result, "Expected nil result")
 		})
 	})
+
+	t.Run("Type", func(t *testing.T) {
+		assert.Equal(t, "chatmodel.Anthropic", anthropicModel.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		assert.Equal(t, anthropicModel.opts.CallbackOptions.Callbacks, anthropicModel.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		// Call the InvocationParams method
+		params := anthropicModel.InvocationParams()
+
+		// Assert the result
+		assert.Equal(t, "claude-v1", params["model_name"])
+		assert.Equal(t, float32(0.5), params["temperature"])
+	})
 }
 
 func TestConvertMessagesToAnthropicPrompt(t *testing.T) {
diff --git a/model/chatmodel/ollama_test.go b/model/chatmodel/ollama_test.go
new file mode 100644
index 0000000..85bd5e6
--- /dev/null
+++ b/model/chatmodel/ollama_test.go
@@ -0,0 +1,128 @@
+package chatmodel
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/hupe1980/golc/schema"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/hupe1980/golc/integration/ollama"
+)
+
+func TestOllama(t *testing.T) {
+	t.Parallel()
+
+	t.Run("Generate", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("Success", func(t *testing.T) {
+			t.Parallel()
+
+			mockClient := &mockOllamaClient{
+				GenerateChatFunc: func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
+					assert.Equal(t, "llama2", req.Model)
+					assert.Len(t, req.Messages, 2)
+					assert.Equal(t, "user", req.Messages[0].Role)
+					assert.Equal(t, "Hello", req.Messages[0].Content)
+					assert.Equal(t, "assistant", req.Messages[1].Role)
+					assert.Equal(t, "How can I help you?", req.Messages[1].Content)
+
+					return &ollama.ChatResponse{
+						Message: &ollama.Message{
+							Role:    "assistant",
+							Content: "I can help you with that.",
+						},
+					}, nil
+				},
+			}
+
+			ollamaModel, err := NewOllama(mockClient)
+			assert.NoError(t, err)
+
+			// Simulate chat messages
+			messages := []schema.ChatMessage{
+				schema.NewHumanChatMessage("Hello"),
+				schema.NewAIChatMessage("How can I help you?"),
+			}
+
+			// Run the model
+			result, err := ollamaModel.Generate(context.Background(), messages)
+			assert.NoError(t, err)
+
+			// Check the result
+			assert.Len(t, result.Generations, 1)
+			assert.Equal(t, "I can help you with that.", result.Generations[0].Text)
+		})
+
+		t.Run("Error", func(t *testing.T) {
+			t.Parallel()
+
+			mockClient := &mockOllamaClient{
+				GenerateChatFunc: func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
+					return nil, errors.New("error generating chat")
+				},
+			}
+
+			ollamaModel, err := NewOllama(mockClient)
+			assert.NoError(t, err)
+
+			messages := []schema.ChatMessage{
+				schema.NewHumanChatMessage("Hello"),
+				schema.NewAIChatMessage("How can I help you?"),
+			}
+
+			result, err := ollamaModel.Generate(context.Background(), messages)
+			assert.Error(t, err)
+			assert.Nil(t, result)
+		})
+	})
+
+	t.Run("Type", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		assert.Equal(t, "chatmodel.Ollama", ollamaModel.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		assert.Equal(t, ollamaModel.opts.CallbackOptions.Callbacks, ollamaModel.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		params := ollamaModel.InvocationParams()
+		assert.NotNil(t, params)
+		assert.Equal(t, float32(0.7), params["temperature"])
+		assert.Equal(t, 256, params["max_tokens"])
+	})
+}
+
+// mockOllamaClient is a mock implementation of the chatmodel.OllamaClient interface.
+type mockOllamaClient struct {
+	GenerateChatFunc func(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error)
+}
+
+// GenerateChat is the mock implementation of the GenerateChat method for mockOllamaClient.
+func (m *mockOllamaClient) GenerateChat(ctx context.Context, req *ollama.ChatRequest) (*ollama.ChatResponse, error) {
+	if m.GenerateChatFunc != nil {
+		return m.GenerateChatFunc(ctx, req)
+	}
+
+	return nil, errors.New("GenerateChatFunc not implemented")
+}
diff --git a/model/llm/ai21_test.go b/model/llm/ai21_test.go
index a559b93..a254f85 100644
--- a/model/llm/ai21_test.go
+++ b/model/llm/ai21_test.go
@@ -9,23 +9,9 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-// MockAI21Client is a custom mock implementation of the AI21Client interface.
-type MockAI21Client struct {
-	CreateCompletionFunc func(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error)
-}
-
-// CreateCompletion mocks the CreateCompletion method of AI21Client.
-func (m *MockAI21Client) CreateCompletion(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error) {
-	if m.CreateCompletionFunc != nil {
-		return m.CreateCompletionFunc(ctx, model, req)
-	}
-
-	return nil, errors.New("CreateCompletionFunc not implemented")
-}
-
 func TestAI21(t *testing.T) {
 	// Initialize the AI21 client with a mock client
-	mockClient := &MockAI21Client{}
+	mockClient := &mockAI21Client{}
 	llm, err := NewAI21FromClient(mockClient, func(o *AI21Options) {
 		o.Model = "j2-mid"
 		o.Temperature = 0.7
@@ -80,4 +66,33 @@ func TestAI21(t *testing.T) {
 		assert.Nil(t, result)
 		assert.Equal(t, expectedError, err)
 	})
+
+	t.Run("Type", func(t *testing.T) {
+		assert.Equal(t, "llm.AI21", llm.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		assert.Equal(t, llm.opts.CallbackOptions.Callbacks, llm.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		params := llm.InvocationParams()
+		assert.NotNil(t, params)
+		assert.Equal(t, 0.7, params["temperature"])
+		assert.Equal(t, "j2-mid", params["model"])
+	})
+}
+
+// mockAI21Client is a custom mock implementation of the AI21Client interface.
+type mockAI21Client struct {
+	CreateCompletionFunc func(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error)
+}
+
+// CreateCompletion mocks the CreateCompletion method of AI21Client.
+func (m *mockAI21Client) CreateCompletion(ctx context.Context, model string, req *ai21.CompleteRequest) (*ai21.CompleteResponse, error) {
+	if m.CreateCompletionFunc != nil {
+		return m.CreateCompletionFunc(ctx, model, req)
+	}
+
+	return nil, errors.New("CreateCompletionFunc not implemented")
 }
diff --git a/model/llm/ollama_test.go b/model/llm/ollama_test.go
new file mode 100644
index 0000000..7593e32
--- /dev/null
+++ b/model/llm/ollama_test.go
@@ -0,0 +1,109 @@
+package llm
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/hupe1980/golc/integration/ollama"
+)
+
+func TestOllama(t *testing.T) {
+	t.Parallel()
+
+	t.Run("Generate", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("Success", func(t *testing.T) {
+			t.Parallel()
+
+			mockClient := &mockOllamaClient{
+				GenerateFunc: func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
+					assert.Equal(t, "llama2", req.Model)
+					assert.Equal(t, "Hello", req.Prompt)
+
+					return &ollama.GenerateResponse{
+						Response: "I can help you with that.",
+					}, nil
+				},
+			}
+
+			ollamaModel, err := NewOllama(mockClient)
+			assert.NoError(t, err)
+
+			// Run the model
+			result, err := ollamaModel.Generate(context.Background(), "Hello")
+			assert.NoError(t, err)
+
+			// Check the result
+			assert.Len(t, result.Generations, 1)
+			assert.Equal(t, "I can help you with that.", result.Generations[0].Text)
+		})
+
+		t.Run("Error", func(t *testing.T) {
+			t.Parallel()
+
+			mockClient := &mockOllamaClient{
+				GenerateFunc: func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
+					return nil, errors.New("error generating chat")
+				},
+			}
+
+			ollamaModel, err := NewOllama(mockClient)
+			assert.NoError(t, err)
+
+			result, err := ollamaModel.Generate(context.Background(), "Hello")
+			assert.Error(t, err)
+			assert.Nil(t, result)
+		})
+	})
+
+	t.Run("Type", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		assert.Equal(t, "llm.Ollama", ollamaModel.Type())
+	})
+
+	t.Run("Callbacks", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		assert.Equal(t, ollamaModel.opts.CallbackOptions.Callbacks, ollamaModel.Callbacks())
+	})
+
+	t.Run("InvocationParams", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockOllamaClient{}
+		ollamaModel, err := NewOllama(mockClient)
+		assert.NoError(t, err)
+
+		params := ollamaModel.InvocationParams()
+		assert.NotNil(t, params)
+		assert.Equal(t, float32(0.7), params["temperature"])
+		assert.Equal(t, 256, params["max_tokens"])
+	})
+}
+
+// mockOllamaClient is a mock implementation of the llm.OllamaClient interface.
+type mockOllamaClient struct {
+	GenerateFunc func(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error)
+}
+
+// Generate is the mock implementation of the Generate method for mockOllamaClient.
+func (m *mockOllamaClient) Generate(ctx context.Context, req *ollama.GenerateRequest) (*ollama.GenerateResponse, error) {
+	if m.GenerateFunc != nil {
+		return m.GenerateFunc(ctx, req)
+	}
+
+	return nil, errors.New("GenerateFunc not implemented")
+}