From f872d721a221df7891da1abb22169e621d5b49ba Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 12 Aug 2024 21:43:25 +0300 Subject: [PATCH] #67: Moved lang & embed router into the routers package & ensured LangProvider interface --- pkg/api/http/handlers.go | 10 ++-- pkg/api/http/server.go | 6 +- pkg/api/servers.go | 4 +- pkg/config/config.go | 8 +-- pkg/gateway.go | 4 +- pkg/provider/anthropic/client.go | 7 +++ pkg/provider/azureopenai/client.go | 7 +++ pkg/provider/bedrock/client.go | 7 +++ pkg/provider/cohere/client.go | 7 +++ pkg/provider/octoml/client.go | 7 +++ pkg/provider/ollama/client.go | 7 +++ pkg/provider/openai/client.go | 7 +++ pkg/provider/testing.go | 5 ++ pkg/routers/config.go | 7 +++ pkg/routers/embed/config.go | 10 ---- pkg/routers/embed_config.go | 16 +++++ .../{embed/router.go => embed_router.go} | 2 +- .../{lang/config.go => lang_config.go} | 60 +++++++++---------- .../config_test.go => lang_config_test.go} | 26 ++++---- .../{lang/router.go => lang_router.go} | 16 ++--- .../router_test.go => lang_router_test.go} | 18 +++--- pkg/routers/{manager => }/manager.go | 17 +++--- pkg/routers/manager/config.go | 9 --- 23 files changed, 161 insertions(+), 106 deletions(-) delete mode 100644 pkg/routers/embed/config.go create mode 100644 pkg/routers/embed_config.go rename pkg/routers/{embed/router.go => embed_router.go} (94%) rename pkg/routers/{lang/config.go => lang_config.go} (80%) rename pkg/routers/{lang/config_test.go => lang_config_test.go} (92%) rename pkg/routers/{lang/router.go => lang_router.go} (93%) rename pkg/routers/{lang/router_test.go => lang_router_test.go} (98%) rename pkg/routers/{manager => }/manager.go (64%) delete mode 100644 pkg/routers/manager/config.go diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index cc2ac3d..58c727c 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" @@ -32,7 +32,7 @@ type Handler = func(c *fiber.Ctx) error // @Failure 400 {object} schemas.Error // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chat [POST] -func LangChatHandler(routerManager *manager.RouterManager) Handler { +func LangChatHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) @@ -73,7 +73,7 @@ func LangChatHandler(routerManager *manager.RouterManager) Handler { } } -func LangStreamRouterValidator(routerManager *manager.RouterManager) Handler { +func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { routerID := c.Params("router") @@ -108,7 +108,7 @@ func LangStreamRouterValidator(routerManager *manager.RouterManager) Handler { // @Failure 426 // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chatStream [GET] -func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *manager.RouterManager) Handler { +func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler { // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket return websocket.New(func(c *websocket.Conn) { routerID := c.Params("router") @@ -176,7 +176,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *manager.Rout // @Produce json // @Success 200 {object} schemas.RouterListSchema // @Router /v1/language/ [GET] -func LangRoutersHandler(routerManager *manager.RouterManager) Handler { +func LangRoutersHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go index 3589996..1242830 100644 --- a/pkg/api/http/server.go +++ b/pkg/api/http/server.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/gofiber/contrib/otelfiber" @@ -25,11 +25,11 @@ import ( type Server struct { config *ServerConfig telemetry *telemetry.Telemetry - routerManager *manager.RouterManager + routerManager *routers.RouterManager server *fiber.App } -func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *manager.RouterManager) (*Server, error) { +func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) { srv := config.ToServer() return &Server{ diff --git a/pkg/api/servers.go b/pkg/api/servers.go index da2d130..4ce8b37 100644 --- a/pkg/api/servers.go +++ b/pkg/api/servers.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "go.uber.org/zap" @@ -19,7 +19,7 @@ type ServerManager struct { telemetry *telemetry.Telemetry } -func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *manager.RouterManager) (*ServerManager, error) { +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { httpServer, err := http.NewServer(cfg.HTTP, tel, router) if err != nil { return nil, err diff --git a/pkg/config/config.go b/pkg/config/config.go index cacdc2a..cd99540 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,15 +2,15 @@ package config import ( "github.com/EinStack/glide/pkg/api" - routerconfig "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/telemetry" ) // Config is a general top-level Glide configuration type Config struct { - Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` - API *api.Config `yaml:"api" validate:"required"` - Routers routerconfig.Config `yaml:"routers" validate:"required"` + Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` + API *api.Config `yaml:"api" validate:"required"` + Routers routers.RoutersConfig `yaml:"routers" validate:"required"` } func DefaultConfig() *Config { diff --git a/pkg/gateway.go b/pkg/gateway.go index a3c8969..b3ce904 100644 --- a/pkg/gateway.go +++ b/pkg/gateway.go @@ -7,7 +7,7 @@ import ( "os/signal" "syscall" - "github.com/EinStack/glide/pkg/routers/manager" + "github.com/EinStack/glide/pkg/routers" "github.com/EinStack/glide/pkg/version" "go.opentelemetry.io/contrib/instrumentation/host" @@ -50,7 +50,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) { tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion)) tel.L().Debug("✅ Config loaded successfully:\n" + configProvider.GetStr()) - routerManager, err := manager.NewManager(&cfg.Routers, tel) + routerManager, err := routers.NewManager(&cfg.Routers, tel) if err != nil { return nil, err } diff --git a/pkg/provider/anthropic/client.go b/pkg/provider/anthropic/client.go index ce697db..2e08b2e 100644 --- a/pkg/provider/anthropic/client.go +++ b/pkg/provider/anthropic/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/azureopenai/client.go b/pkg/provider/azureopenai/client.go index 5c34a15..6ec9046 100644 --- a/pkg/provider/azureopenai/client.go +++ b/pkg/provider/azureopenai/client.go @@ -5,6 +5,8 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/provider/openai" @@ -28,6 +30,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Azure OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL := fmt.Sprintf( diff --git a/pkg/provider/bedrock/client.go b/pkg/provider/bedrock/client.go index 673cb49..aa3905f 100644 --- a/pkg/provider/bedrock/client.go +++ b/pkg/provider/bedrock/client.go @@ -7,6 +7,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -36,6 +38,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint, providerConfig.ModelName, "/invoke") diff --git a/pkg/provider/cohere/client.go b/pkg/provider/cohere/client.go index 3393e01..c3e43eb 100644 --- a/pkg/provider/cohere/client.go +++ b/pkg/provider/cohere/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Cohere client for the Cohere API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/octoml/client.go b/pkg/provider/octoml/client.go index 420a991..30ab779 100644 --- a/pkg/provider/octoml/client.go +++ b/pkg/provider/octoml/client.go @@ -6,6 +6,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -31,6 +33,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OctoML client for the OctoML API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/ollama/client.go b/pkg/provider/ollama/client.go index 85192b6..df624cd 100644 --- a/pkg/provider/ollama/client.go +++ b/pkg/provider/ollama/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -24,6 +26,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/openai/client.go b/pkg/provider/openai/client.go index 30a0438..8567e26 100644 --- a/pkg/provider/openai/client.go +++ b/pkg/provider/openai/client.go @@ -5,6 +5,8 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/clients" "go.uber.org/zap" @@ -29,6 +31,11 @@ type Client struct { logger *zap.Logger } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/provider/testing.go b/pkg/provider/testing.go index ca9a00d..7213334 100644 --- a/pkg/provider/testing.go +++ b/pkg/provider/testing.go @@ -122,6 +122,11 @@ type Mock struct { modelName *string } +// ensure interfaces +var ( + _ LangProvider = (*Mock)(nil) +) + func NewMock(modelName *string, responses []RespMock) *Mock { return &Mock{ idx: 0, diff --git a/pkg/routers/config.go b/pkg/routers/config.go index a3c8f69..6d4610a 100644 --- a/pkg/routers/config.go +++ b/pkg/routers/config.go @@ -22,3 +22,10 @@ func DefaultConfig() RouterConfig { Retry: retry.DefaultExpRetryConfig(), } } + +// RoutersConfig defines a config for a set of supported router types +// TODO: remove nolint after renaming the package +type RoutersConfig struct { //nolint: revive + LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers + // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` +} diff --git a/pkg/routers/embed/config.go b/pkg/routers/embed/config.go deleted file mode 100644 index 49f4821..0000000 --- a/pkg/routers/embed/config.go +++ /dev/null @@ -1,10 +0,0 @@ -package embed - -import ( - "github.com/EinStack/glide/pkg/routers" -) - -type EmbeddingRouterConfig struct { - routers.RouterConfig - // Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests -} diff --git a/pkg/routers/embed_config.go b/pkg/routers/embed_config.go new file mode 100644 index 0000000..a93e937 --- /dev/null +++ b/pkg/routers/embed_config.go @@ -0,0 +1,16 @@ +package routers + +import ( + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/provider" +) + +type ( + EmbedModelConfig = extmodel.Config[*provider.Config] + EmbedModelPoolConfig = []EmbedModelConfig +) + +type EmbeddingRouterConfig struct { + RouterConfig + Models EmbedModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} diff --git a/pkg/routers/embed/router.go b/pkg/routers/embed_router.go similarity index 94% rename from pkg/routers/embed/router.go rename to pkg/routers/embed_router.go index 9068537..ef81d59 100644 --- a/pkg/routers/embed/router.go +++ b/pkg/routers/embed_router.go @@ -1,4 +1,4 @@ -package embed +package routers type EmbeddingRouter struct { // routerID lang.RouterID diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang_config.go similarity index 80% rename from pkg/routers/lang/config.go rename to pkg/routers/lang_config.go index eb4a56c..9a7b685 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang_config.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "fmt" @@ -8,8 +8,6 @@ import ( "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -18,40 +16,40 @@ import ( ) type ( - ModelConfig = extmodel.Config[*provider.Config] - ModelPoolConfig = []ModelConfig + LangModelConfig = extmodel.Config[*provider.Config] + LangModelPoolConfig = []LangModelConfig ) -// RouterConfig -type RouterConfig struct { - routers.RouterConfig - Models ModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +// LangRouterConfig +type LangRouterConfig struct { + RouterConfig + Models LangModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests } -type RouterConfigOption = func(*RouterConfig) +type RouterConfigOption = func(*LangRouterConfig) -func WithModels(models ModelPoolConfig) RouterConfigOption { - return func(c *RouterConfig) { +func WithModels(models LangModelPoolConfig) RouterConfigOption { + return func(c *LangRouterConfig) { c.Models = models } } -func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { - config := &RouterConfig{ - RouterConfig: routers.DefaultConfig(), +func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *LangRouterConfig { + cfg := &LangRouterConfig{ + RouterConfig: DefaultConfig(), } - config.ID = RouterID + cfg.ID = RouterID for _, o := range opt { - o(config) + o(cfg) } - return config + return cfg } // BuildModels creates LanguageModel slice out of the given config -func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop +func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) @@ -147,7 +145,7 @@ func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.Langua return chatModels, chatStreamModels, nil } -func (c *RouterConfig) BuildRetry() *retry.ExpRetry { +func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { retryConfig := c.Retry maxDelay := time.Duration(*retryConfig.MaxDelay) @@ -159,7 +157,7 @@ func (c *RouterConfig) BuildRetry() *retry.ExpRetry { ) } -func (c *RouterConfig) BuildRouting( +func (c *LangRouterConfig) BuildRouting( chatModels []*extmodel.LanguageModel, chatStreamModels []*extmodel.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { @@ -190,25 +188,25 @@ func (c *RouterConfig) BuildRouting( return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) } -func DefaultRouterConfig() *RouterConfig { - return &RouterConfig{ - RouterConfig: routers.DefaultConfig(), +func DefaultRouterConfig() *LangRouterConfig { + return &LangRouterConfig{ + RouterConfig: DefaultConfig(), } } -func (c *RouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = *DefaultRouterConfig() +func (c LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + c = *DefaultRouterConfig() - type plain RouterConfig // to avoid recursion + type plain LangRouterConfig // to avoid recursion - return unmarshal((*plain)(c)) + return unmarshal((plain)(c)) } -type RoutersConfig []RouterConfig +type LangRoutersConfig []LangRouterConfig -func (c RoutersConfig) Build(tel *telemetry.Telemetry) ([]*Router, error) { +func (c LangRoutersConfig) Build(tel *telemetry.Telemetry) ([]*LangRouter, error) { seenIDs := make(map[string]bool, len(c)) - langRouters := make([]*Router, 0, len(c)) + langRouters := make([]*LangRouter, 0, len(c)) var errs error diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang_config_test.go similarity index 92% rename from pkg/routers/lang/config_test.go rename to pkg/routers/lang_config_test.go index 3f36c70..81998f5 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang_config_test.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "testing" @@ -18,10 +18,10 @@ import ( func TestRouterConfig_BuildModels(t *testing.T) { defaultParams := openai.DefaultParams() - cfg := RoutersConfig{ + cfg := LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -39,7 +39,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ), *NewRouterConfig( "second_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -74,7 +74,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { cfg := NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -116,14 +116,14 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { tests := []struct { name string - config RoutersConfig + config LangRoutersConfig }{ { "duplicated router IDs", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -141,7 +141,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { ), *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -161,10 +161,10 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "duplicated model IDs", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{ + WithModels(LangModelPoolConfig{ { ID: "first_model", Enabled: true, @@ -197,10 +197,10 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { }, { "no models", - RoutersConfig{ + LangRoutersConfig{ *NewRouterConfig( "first_router", - WithModels(ModelPoolConfig{}), + WithModels(LangModelPoolConfig{}), ), }, }, diff --git a/pkg/routers/lang/router.go b/pkg/routers/lang_router.go similarity index 93% rename from pkg/routers/lang/router.go rename to pkg/routers/lang_router.go index 9dcc0c3..fbae5f2 100644 --- a/pkg/routers/lang/router.go +++ b/pkg/routers/lang_router.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "context" @@ -17,9 +17,9 @@ var ErrNoModels = errors.New("no models configured for router") type RouterID = string -type Router struct { +type LangRouter struct { routerID RouterID - Config *RouterConfig + Config *LangRouterConfig chatModels []*extmodel.LanguageModel chatStreamModels []*extmodel.LanguageModel chatRouting routing.LangModelRouting @@ -29,7 +29,7 @@ type Router struct { logger *zap.Logger } -func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) { +func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) { chatModels, chatStreamModels, err := cfg.BuildModels(tel) if err != nil { return nil, err @@ -40,7 +40,7 @@ func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) return nil, err } - router := &Router{ + router := &LangRouter{ routerID: cfg.ID, Config: cfg, chatModels: chatModels, @@ -55,11 +55,11 @@ func NewLangRouter(cfg *RouterConfig, tel *telemetry.Telemetry) (*Router, error) return router, err } -func (r *Router) ID() RouterID { +func (r *LangRouter) ID() RouterID { return r.routerID } -func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -115,7 +115,7 @@ func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.C return nil, &schemas.ErrNoModelAvailable } -func (r *Router) ChatStream( +func (r *LangRouter) ChatStream( ctx context.Context, req *schemas.ChatStreamRequest, respC chan<- *schemas.ChatStreamMessage, diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang_router_test.go similarity index 98% rename from pkg/routers/lang/router_test.go rename to pkg/routers/lang_router_test.go index 2b71928..671b075 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang_router_test.go @@ -1,4 +1,4 @@ -package lang +package routers import ( "context" @@ -46,7 +46,7 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -101,7 +101,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { expectedModels := []string{"third", "third"} - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -149,7 +149,7 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -192,7 +192,7 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -237,7 +237,7 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), @@ -293,7 +293,7 @@ func TestLangRouter_ChatStream(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_stream_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -362,7 +362,7 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_stream_router", retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), chatRouting: routing.NewPriority(modelPool), @@ -431,7 +431,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { modelPool = append(modelPool, model) } - router := Router{ + router := LangRouter{ routerID: "test_router", retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), chatRouting: routing.NewPriority(modelPool), diff --git a/pkg/routers/manager/manager.go b/pkg/routers/manager.go similarity index 64% rename from pkg/routers/manager/manager.go rename to pkg/routers/manager.go index add7201..f719d09 100644 --- a/pkg/routers/manager/manager.go +++ b/pkg/routers/manager.go @@ -1,26 +1,25 @@ -package manager +package routers import ( "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers/lang" "github.com/EinStack/glide/pkg/telemetry" ) type RouterManager struct { - Config *Config + Config *RoutersConfig tel *telemetry.Telemetry - langRouterMap *map[string]*lang.Router - langRouters []*lang.Router + langRouterMap *map[string]*LangRouter + langRouters []*LangRouter } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers -func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { +func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*RouterManager, error) { langRouters, err := cfg.LanguageRouters.Build(tel) if err != nil { return nil, err } - langRouterMap := make(map[string]*lang.Router, len(langRouters)) + langRouterMap := make(map[string]*LangRouter, len(langRouters)) for _, router := range langRouters { langRouterMap[router.ID()] = router @@ -36,12 +35,12 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return &manager, err } -func (r *RouterManager) GetLangRouters() []*lang.Router { +func (r *RouterManager) GetLangRouters() []*LangRouter { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*lang.Router, error) { +func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } diff --git a/pkg/routers/manager/config.go b/pkg/routers/manager/config.go deleted file mode 100644 index aaaeac0..0000000 --- a/pkg/routers/manager/config.go +++ /dev/null @@ -1,9 +0,0 @@ -package manager - -import "github.com/EinStack/glide/pkg/routers/lang" - -// Config defines a config for a set of supported router types -type Config struct { - LanguageRouters lang.RoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers - // EmbeddingRouters []EmbeddingRouterConfig `yaml:"embedding" validate:"required,dive"` -}