llms: extract common llms errors to shared declaration #925

Open · wgeorgecook wants to merge 21 commits into main from ExtractCommonLlmsErrorsToSharedDeclaration

Commits (21), all by wgeorgecook:
02935ed  llms: refactor ErrIncompleteEmbedding to shared declaration (Jun 22, 2024)
cffb37c  llms: remove unused error declarations (Jun 22, 2024)
90290de  llms: refactor ErrEmptyResponse to shared declaration (Jun 22, 2024)
cba65c1  llms: add comments for shared errors (Jun 22, 2024)
3e36497  llms: remove unused error declarations (Jun 22, 2024)
2c2271c  golangci-lint: fix missing periods in comments (Jun 22, 2024)
bc8f270  llms/googleai: extract defaults to static declarations for golangci (Jun 22, 2024)
9d6ddb4  llms: gofumpt files for golangci (Jun 22, 2024)
58debc6  llms/maritaca: statically declare parts length for golangci (Jun 22, 2024)
1986ee0  llms/openai: use canonical header names for golangci (Jun 22, 2024)
0d3ca54  llms/cloudflare: use http status code variable for golangci (Jun 22, 2024)
6c02795  llms/googleai: update nolint directive for golangci (Jun 22, 2024)
2660d83  llms/bedrock: extract default tokens to statically defined shared var… (Jun 22, 2024)
19374c6  llms/cohere: use canonical headers for golangci (Jun 22, 2024)
d5c97f4  llms/anthropic: use canonical headers for golangci (Jun 22, 2024)
7b3ca56  llms/huggingface: replace http handler require calls with t.Error on … (Jun 22, 2024)
b18e698  Merge branch 'main' of github.com:wgeorgecook/langchaingo into Extrac… (Aug 10, 2024)
a1814b1  llms: change language for error docstring from 'is thrown' to 'is ret… (Sep 13, 2024)
1c3fee1  Merge branch 'main' into ExtractCommonLlmsErrorsToSharedDeclaration (Sep 13, 2024)
aae4185  llms: use imported ErrEmptyResponse (Sep 13, 2024)
498bbdb  golangci-lint: remove duplicate 'is' in docstring (Sep 13, 2024)
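Worth noting why the consolidation matters beyond deduplication: sentinel errors created with errors.New compare by identity, not by message text, so the per-provider copies of ErrEmptyResponse could never be matched by a single errors.Is check. A minimal sketch of the failure mode the PR removes (variable names are hypothetical):

```go
package main

import (
	"errors"
	"fmt"
)

// Stand-ins for the old per-provider sentinels: same text, distinct values.
var (
	anthropicErr = errors.New("no response")
	cohereErr    = errors.New("no response")
)

func main() {
	// Identity, not message text, decides the match: this prints false.
	fmt.Println(errors.Is(anthropicErr, cohereErr))
}
```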
5 changes: 2 additions & 3 deletions llms/anthropic/anthropicllm.go
@@ -14,7 +14,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse            = errors.New("no response")
 	ErrMissingToken             = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 	ErrInvalidContentType       = errors.New("invalid content type")
@@ -92,7 +91,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 
 func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
 	if len(messages) == 0 || len(messages[0].Parts) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	msg0 := messages[0]
@@ -153,7 +152,7 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
 		return nil, fmt.Errorf("anthropic: failed to create message: %w", err)
 	}
 	if result == nil {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	choices := make([]*llms.ContentChoice, len(result.Content))
3 changes: 0 additions & 3 deletions llms/anthropic/internal/anthropicclient/anthropicclient.go
@@ -16,9 +16,6 @@ const (
 	defaultModel = "claude-3-5-sonnet-20240620"
 )
 
-// ErrEmptyResponse is returned when the Anthropic API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 // Client is a client for the Anthropic API.
 type Client struct {
 	token string
7 changes: 7 additions & 0 deletions llms/bedrock/internal/bedrockclient/defaults.go
@@ -0,0 +1,7 @@
+package bedrockclient
+
+const (
+	DefaultMaxTokenLength2048 = 2048
+	DefaultMaxTokenLength512  = 512
+	DefaultMaxTokenLength20   = 20
+)
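The getMaxTokens helper these constants feed is not part of this diff. For context, a minimal sketch of its assumed behavior (a hypothetical reconstruction; the real body lives elsewhere in bedrockclient and may differ):

```go
// Hypothetical sketch, not the actual implementation: fall back to the
// provider default when the caller leaves MaxTokens unset.
func getMaxTokens(maxTokens, fallback int) int {
	if maxTokens <= 0 {
		return fallback
	}
	return maxTokens
}
```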
2 changes: 1 addition & 1 deletion llms/bedrock/internal/bedrockclient/provider_ai21.go
@@ -81,7 +81,7 @@ func createAi21Completion(ctx context.Context, client *bedrockruntime.Client, mo
 		Prompt:        txt,
 		Temperature:   options.Temperature,
 		TopP:          options.TopP,
-		MaxTokens:     getMaxTokens(options.MaxTokens, 2048),
+		MaxTokens:     getMaxTokens(options.MaxTokens, DefaultMaxTokenLength2048),
 		StopSequences: options.StopWords,
 		CountPenalty: struct {
 			Scale float64 `json:"scale"`
2 changes: 1 addition & 1 deletion llms/bedrock/internal/bedrockclient/provider_amazon.go
@@ -67,7 +67,7 @@ func createAmazonCompletion(ctx context.Context,
 	inputContent := amazonTextGenerationInput{
 		InputText: txt,
 		TextGenerationConfig: amazonTextGenerationConfigInput{
-			MaxTokens:     getMaxTokens(options.MaxTokens, 512),
+			MaxTokens:     getMaxTokens(options.MaxTokens, DefaultMaxTokenLength512),
 			TopP:          options.TopP,
 			Temperature:   options.Temperature,
 			StopSequences: options.StopWords,
2 changes: 1 addition & 1 deletion llms/bedrock/internal/bedrockclient/provider_anthropic.go
@@ -134,7 +134,7 @@ func createAnthropicCompletion(ctx context.Context,
 
 	input := anthropicTextGenerationInput{
 		AnthropicVersion: AnthropicLatestVersion,
-		MaxTokens:        getMaxTokens(options.MaxTokens, 2048),
+		MaxTokens:        getMaxTokens(options.MaxTokens, DefaultMaxTokenLength2048),
 		System:           systemPrompt,
 		Messages:         inputContents,
 		Temperature:      options.Temperature,
2 changes: 1 addition & 1 deletion llms/bedrock/internal/bedrockclient/provider_cohere.go
@@ -74,7 +74,7 @@ func createCohereCompletion(ctx context.Context,
 		Temperature:    options.Temperature,
 		P:              options.TopP,
 		K:              options.TopK,
-		MaxTokens:      getMaxTokens(options.MaxTokens, 20),
+		MaxTokens:      getMaxTokens(options.MaxTokens, DefaultMaxTokenLength20),
 		StopSequences:  options.StopWords,
 		NumGenerations: options.CandidateCount,
 	}
2 changes: 1 addition & 1 deletion llms/bedrock/internal/bedrockclient/provider_meta.go
@@ -56,7 +56,7 @@ func createMetaCompletion(ctx context.Context,
 		Prompt:      txt,
 		Temperature: options.Temperature,
 		TopP:        options.TopP,
-		MaxGenLen:   getMaxTokens(options.MaxTokens, 512),
+		MaxGenLen:   getMaxTokens(options.MaxTokens, DefaultMaxTokenLength512),
 	}
 
 	body, err := json.Marshal(input)
9 changes: 2 additions & 7 deletions llms/cloudflare/cloudflarellm.go
@@ -10,11 +10,6 @@ import (
 	"github.com/tmc/langchaingo/llms/cloudflare/internal/cloudflareclient"
 )
 
-var (
-	ErrEmptyResponse       = errors.New("no response")
-	ErrIncompleteEmbedding = errors.New("not all input got embedded")
-)
-
 // LLM is a cloudflare LLM implementation.
 type LLM struct {
 	CallbacksHandler callbacks.Handler
@@ -147,11 +142,11 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 	}
 
 	if len(res.Result.Data) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	if len(inputTexts) != len(res.Result.Data) {
-		return res.Result.Data, ErrIncompleteEmbedding
+		return res.Result.Data, llms.ErrIncompleteEmbedding
 	}
 
 	return res.Result.Data, nil
6 changes: 3 additions & 3 deletions llms/cloudflare/internal/cloudflareclient/api.go
@@ -37,7 +37,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, texts *CreateEmbeddingRequ
 		return nil, err
 	}
 
-	if resp.StatusCode > 299 {
+	if resp.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
 
@@ -81,7 +81,7 @@ func (c *Client) GenerateContent(ctx context.Context, request *GenerateContentRe
 		return nil, err
 	}
 
-	if response.StatusCode > 299 {
+	if response.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
 
@@ -165,7 +165,7 @@ func (c *Client) Summarize(ctx context.Context, inputText string, maxLength int)
 		return nil, err
 	}
 
-	if resp.StatusCode > 299 {
+	if resp.StatusCode >= http.StatusMultipleChoices {
 		return nil, fmt.Errorf("error: %s", body)
 	}
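Since http.StatusMultipleChoices is 300, the rewritten guard `resp.StatusCode >= http.StatusMultipleChoices` accepts exactly the same codes as the old `resp.StatusCode > 299`; the change only swaps the magic number for a named constant. A quick self-contained sanity check:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// The old and new guards agree on every code around the boundary.
	for code := 298; code <= 302; code++ {
		oldGuard := code > 299
		newGuard := code >= http.StatusMultipleChoices // 300
		fmt.Println(code, oldGuard == newGuard)        // always prints true
	}
}
```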
3 changes: 1 addition & 2 deletions llms/cohere/coherellm.go
@@ -11,8 +11,7 @@ import (
 )
 
 var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrMissingToken  = errors.New("missing the COHERE_API_KEY key, set it in the COHERE_API_KEY environment variable")
+	ErrMissingToken = errors.New("missing the COHERE_API_KEY key, set it in the COHERE_API_KEY environment variable")
 
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 )
12 changes: 5 additions & 7 deletions llms/cohere/internal/cohereclient/cohereclient.go
@@ -10,12 +10,10 @@ import (
 	"strings"
 
 	"github.com/cohere-ai/tokenizer"
+	"github.com/tmc/langchaingo/llms"
 )
 
-var (
-	ErrEmptyResponse = errors.New("empty response")
-	ErrModelNotFound = errors.New("model not found")
-)
+var ErrModelNotFound = errors.New("model not found")
 
 type Client struct {
 	token string
@@ -111,8 +109,8 @@ func (c *Client) CreateGeneration(ctx context.Context, r *GenerationRequest) (*G
 		return nil, fmt.Errorf("create request: %w", err)
 	}
 
-	req.Header.Set("content-type", "application/json")
-	req.Header.Set("authorization", "bearer "+c.token)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "bearer "+c.token)
 
 	res, err := c.httpClient.Do(req)
 	if err != nil {
@@ -129,7 +127,7 @@ func (c *Client) CreateGeneration(ctx context.Context, r *GenerationRequest) (*G
 		if strings.HasPrefix(response.Message, "model not found") {
 			return nil, ErrModelNotFound
 		}
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	var generation Generation
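The header change is behavior-preserving: net/http canonicalizes header keys on Set (via textproto.CanonicalMIMEHeaderKey), so the old lowercase keys were already sent as Content-Type and Authorization. The rewrite just makes the source match the wire format and satisfies the canonicalheader linter. Only keys are canonicalized, which is why the lowercase "bearer" scheme in the value is untouched (auth scheme names are case-insensitive per RFC 7235). A small demonstration:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	h := http.Header{}
	h.Set("content-type", "application/json") // key stored as "Content-Type"
	fmt.Println(h.Get("Content-Type"))        // "application/json"
}
```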
5 changes: 1 addition & 4 deletions llms/ernie/erniellm.go
@@ -11,10 +11,7 @@ import (
 	"github.com/tmc/langchaingo/llms/ernie/internal/ernieclient"
 )
 
-var (
-	ErrEmptyResponse = errors.New("no response")
-	ErrCodeResponse  = errors.New("has error code")
-)
+var ErrCodeResponse = errors.New("has error code")
 
 type LLM struct {
 	client *ernieclient.Client
5 changes: 3 additions & 2 deletions llms/ernie/internal/ernieclient/ernieclient.go
@@ -11,14 +11,15 @@ import (
 	"net/http"
 	"strings"
 	"time"
+
+	"github.com/tmc/langchaingo/llms"
 )
 
 var (
 	ErrNotSetAuth      = errors.New("both accessToken and apiKey secretKey are not set")
 	ErrCompletionCode  = errors.New("completion API returned unexpected status code")
 	ErrAccessTokenCode = errors.New("get access_token API returned unexpected status code")
 	ErrEmbeddingCode   = errors.New("embedding API returned unexpected status code")
-	ErrEmptyResponse   = errors.New("empty response")
 )
 
 // Client is a client for the ERNIE API.
@@ -285,7 +286,7 @@ func (c *Client) CreateChat(ctx context.Context, r *ChatRequest) (*ChatResponse,
 	}
 
 	if resp.Result == "" && resp.FunctionCall == nil {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	return resp, nil
11 changes: 11 additions & 0 deletions llms/errors.go
@@ -0,0 +1,11 @@
+package llms
+
+import "errors"
+
+var (
+	// ErrEmptyResponse is returned when an LLM returns an empty response.
+	ErrEmptyResponse = errors.New("no response")
+	// ErrIncompleteEmbedding is returned when the length of an embedding
+	// request does not match the length of the returned embeddings array.
+	ErrIncompleteEmbedding = errors.New("not all input got embedded")
+)
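With the shared declarations in place, provider-agnostic callers can match these failures without importing each backend package. A minimal caller-side sketch (handleErr is illustrative, not part of the PR):

```go
package main

import (
	"errors"
	"fmt"

	"github.com/tmc/langchaingo/llms"
)

// handleErr shows how the shared sentinels are matched regardless of
// which provider produced the error.
func handleErr(err error) {
	switch {
	case err == nil:
		return
	case errors.Is(err, llms.ErrEmptyResponse):
		fmt.Println("model returned no content; consider retrying")
	case errors.Is(err, llms.ErrIncompleteEmbedding):
		fmt.Println("embedding count does not match input count")
	default:
		fmt.Println("unexpected error:", err)
	}
}

func main() {
	handleErr(llms.ErrEmptyResponse) // matches the shared sentinel
}
```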
15 changes: 6 additions & 9 deletions llms/googleai/internal/palmclient/palmclient.go
@@ -21,10 +21,10 @@ var (
 )
 
 var defaultParameters = map[string]interface{}{ //nolint:gochecknoglobals
-	"temperature":     0.2, //nolint:gomnd
-	"maxOutputTokens": 256, //nolint:gomnd
-	"topP":            0.8, //nolint:gomnd
-	"topK":            40,  //nolint:gomnd
+	"temperature":     0.2, //nolint:all
+	"maxOutputTokens": 256, //nolint:all
+	"topP":            0.8, //nolint:all
+	"topK":            40,  //nolint:all
 }
 
 const (
@@ -65,9 +65,6 @@ func New(ctx context.Context, projectID, location string, opts ...option.ClientO
 	}, nil
 }
 
-// ErrEmptyResponse is returned when the OpenAI API returns an empty response.
-var ErrEmptyResponse = errors.New("empty response")
-
 // CompletionRequest is a request to create a completion.
 type CompletionRequest struct {
 	Prompts []string `json:"prompts"`
@@ -290,7 +287,7 @@ func (c *PaLMClient) batchPredict(ctx context.Context, model string, prompts []s
 		return nil, err
 	}
 	if len(resp.GetPredictions()) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return resp.GetPredictions(), nil
 }
@@ -329,7 +326,7 @@ func (c *PaLMClient) chat(ctx context.Context, r *ChatRequest) ([]*structpb.Valu
 		return nil, err
 	}
 	if len(resp.GetPredictions()) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	return resp.GetPredictions(), nil
 }
33 changes: 23 additions & 10 deletions llms/googleai/option.go
@@ -25,18 +25,31 @@ type Options struct {
 	ClientOptions []option.ClientOption
 }
 
+const (
+	CloudProject          = ""
+	CloudLocation         = ""
+	DefaultModel          = "gemini-pro"
+	DefaultEmbeddingModel = "embedding-001"
+	DefaultCandidateCount = 1
+	DefaultMaxTokens      = 2048
+	DefaultTemperature    = 0.5
+	DefaultTopK           = 3
+	DefaultTopP           = 0.95
+	DefaultHarmThreshold  = HarmBlockOnlyHigh
+)
+
 func DefaultOptions() Options {
 	return Options{
-		CloudProject:          "",
-		CloudLocation:         "",
-		DefaultModel:          "gemini-pro",
-		DefaultEmbeddingModel: "embedding-001",
-		DefaultCandidateCount: 1,
-		DefaultMaxTokens:      2048,
-		DefaultTemperature:    0.5,
-		DefaultTopK:           3,
-		DefaultTopP:           0.95,
-		HarmThreshold:         HarmBlockOnlyHigh,
+		CloudProject:          CloudProject,
+		CloudLocation:         CloudLocation,
+		DefaultModel:          DefaultModel,
+		DefaultEmbeddingModel: DefaultEmbeddingModel,
+		DefaultCandidateCount: DefaultCandidateCount,
+		DefaultMaxTokens:      DefaultMaxTokens,
+		DefaultTemperature:    DefaultTemperature,
+		DefaultTopK:           DefaultTopK,
+		DefaultTopP:           DefaultTopP,
+		HarmThreshold:         DefaultHarmThreshold,
 	}
 }
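Exporting the defaults also lets callers start from DefaultOptions and adjust a single knob while still referring to the shared constants. A small usage sketch (assuming only the identifiers declared in this file):

```go
package main

import "github.com/tmc/langchaingo/llms/googleai"

func main() {
	// Reuse the exported defaults instead of re-hardcoding their values.
	opts := googleai.DefaultOptions()
	opts.DefaultMaxTokens = 2 * googleai.DefaultMaxTokens // raise the cap
	opts.DefaultTemperature = 0.9                         // other fields keep the shared defaults
	_ = opts
}
```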
3 changes: 1 addition & 2 deletions llms/googleai/palm/palm_llm.go
@@ -13,7 +13,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse            = errors.New("no response")
 	ErrMissingProjectID         = errors.New("missing the GCP Project ID, set it in the GOOGLE_CLOUD_PROJECT environment variable") //nolint:lll
 	ErrMissingLocation          = errors.New("missing the GCP Location, set it in the GOOGLE_CLOUD_LOCATION environment variable")  //nolint:lll
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
@@ -85,7 +84,7 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
 	}
 
 	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	if len(inputTexts) != len(embeddings) {
 		return embeddings, ErrUnexpectedResponseLength
3 changes: 1 addition & 2 deletions llms/huggingface/huggingfacellm.go
@@ -11,7 +11,6 @@ import (
 )
 
 var (
-	ErrEmptyResponse            = errors.New("empty response")
 	ErrMissingToken             = errors.New("missing the Hugging Face API token. Set it in the HUGGINGFACEHUB_API_TOKEN environment variable") //nolint:lll
 	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
 )
@@ -115,7 +114,7 @@ func (o *LLM) CreateEmbedding(
 		return nil, err
 	}
 	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	if len(inputTexts) != len(embeddings) {
 		return embeddings, ErrUnexpectedResponseLength
11 changes: 5 additions & 6 deletions llms/huggingface/internal/huggingfaceclient/huggingfaceclient.go
@@ -4,13 +4,12 @@ import (
 	"context"
 	"errors"
 	"fmt"
-)
 
-var (
-	ErrInvalidToken  = errors.New("invalid token")
-	ErrEmptyResponse = errors.New("empty response")
+	"github.com/tmc/langchaingo/llms"
 )
 
+var ErrInvalidToken = errors.New("invalid token")
+
 type Client struct {
 	Token string
 	Model string
@@ -64,7 +63,7 @@ func (c *Client) RunInference(ctx context.Context, request *InferenceRequest) (*
 		return nil, fmt.Errorf("failed to run inference: %w", err)
 	}
 	if len(resp) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 	text := resp[0].Text
 	// TODO: Add response cleaning based on Model.
@@ -96,7 +95,7 @@ func (c *Client) CreateEmbedding(
 	}
 
 	if len(resp) == 0 {
-		return nil, ErrEmptyResponse
+		return nil, llms.ErrEmptyResponse
 	}
 
 	return resp, nil