diff --git a/pkg/models/config.go b/pkg/extmodel/config.go similarity index 86% rename from pkg/models/config.go rename to pkg/extmodel/config.go index edc67ca..2266fca 100644 --- a/pkg/models/config.go +++ b/pkg/extmodel/config.go @@ -1,17 +1,18 @@ -package models +package extmodel import ( "fmt" + "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/routers/latency" "github.com/EinStack/glide/pkg/telemetry" ) // Config defines an extra configuration for a model wrapper around a provider -type Config[P provider.ProviderConfig] struct { +type Config[P providers.Configurer] struct { ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` @@ -22,7 +23,7 @@ type Config[P provider.ProviderConfig] struct { Provider P `yaml:"provider" json:"provider"` } -func NewConfig[P provider.ProviderConfig](ID string) *Config[P] { +func NewConfig[P providers.Configurer](ID string) *Config[P] { config := DefaultConfig[P]() config.ID = ID @@ -30,7 +31,7 @@ func NewConfig[P provider.ProviderConfig](ID string) *Config[P] { return &config } -func DefaultConfig[P provider.ProviderConfig]() Config[P] { +func DefaultConfig[P providers.Configurer]() Config[P] { return Config[P]{ Enabled: true, Client: clients.DefaultClientConfig(), diff --git a/pkg/models/lang.go b/pkg/extmodel/lang.go similarity index 91% rename from pkg/models/lang.go rename to pkg/extmodel/lang.go index 2d50bd3..e3243cb 100644 --- a/pkg/models/lang.go +++ b/pkg/extmodel/lang.go @@ -1,11 +1,11 @@ -package models +package extmodel import ( "context" "io" "time" - "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/clients" health2 "github.com/EinStack/glide/pkg/resiliency/health" @@ -18,7 +18,7 @@ import ( ) type LangModel interface { - Model + Interface Provider() string ModelName() string Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) @@ -32,14 +32,14 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client provider.LangProvider + client providers.LangProvider healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client provider.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, @@ -170,10 +170,10 @@ func (m *LanguageModel) ModelName() string { return m.client.ModelName() } -func ChatLatency(model Model) *latency.MovingAverage { +func ChatLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatLatency() } -func ChatStreamLatency(model Model) *latency.MovingAverage { +func ChatStreamLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatStreamLatency() } diff --git a/pkg/extmodel/model.go b/pkg/extmodel/model.go new file mode 100644 index 0000000..b250c47 --- /dev/null +++ b/pkg/extmodel/model.go @@ -0,0 +1,11 @@ +package extmodel + +import "github.com/EinStack/glide/pkg/config/fields" + +// Interface represent a configured external modality-agnostic model with its routing properties and status +type Interface interface { + ID() string + Healthy() bool + LatencyUpdateInterval() *fields.Duration + Weight() int +} diff --git a/pkg/providers/testing/models.go b/pkg/extmodel/testing.go similarity index 89% rename from pkg/providers/testing/models.go rename to pkg/extmodel/testing.go index 57500d2..6d51ca7 100644 --- a/pkg/providers/testing/models.go +++ b/pkg/extmodel/testing.go @@ -1,10 +1,9 @@ -package testing +package extmodel import ( "time" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/routers/latency" ) @@ -53,6 +52,6 @@ func (m LangModelMock) Weight() int { return m.weight } -func ChatMockLatency(model models.Model) *latency.MovingAverage { +func ChatMockLatency(model Interface) *latency.MovingAverage { return model.(LangModelMock).chatLatency } diff --git a/pkg/models/model.go b/pkg/models/model.go deleted file mode 100644 index 707efee..0000000 --- a/pkg/models/model.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -import "github.com/EinStack/glide/pkg/config/fields" - -// Model represent a configured external modality-agnostic model with its routing properties and status -type Model interface { - ID() string - Healthy() bool - LatencyUpdateInterval() *fields.Duration - Weight() int -} diff --git a/pkg/provider/config.go b/pkg/provider/config.go deleted file mode 100644 index 0424e83..0000000 --- a/pkg/provider/config.go +++ /dev/null @@ -1,12 +0,0 @@ -package provider - -import ( - "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/telemetry" -) - -// TODO: ProviderConfig should be more generic, not tied to LangProviders -type ProviderConfig interface { - UnmarshalYAML(unmarshal func(interface{}) error) error - ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) -} diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 466f07e..6469710 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/EinStack/glide/pkg/provider" "github.com/go-playground/validator/v10" "gopkg.in/yaml.v3" @@ -22,12 +21,17 @@ func init() { validate = validator.New() } -// TODO: rename DynLangProvider to DynLangProviderConfig -type DynLangProvider map[provider.ProviderID]interface{} +// TODO: Configurer should be more generic, not tied to LangProviders +type Configurer interface { + UnmarshalYAML(unmarshal func(interface{}) error) error + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} + +type Config map[ProviderID]interface{} -var _ provider.ProviderConfig = (*DynLangProvider)(nil) +var _ Configurer = (*Config)(nil) -func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (p Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { for providerID, configValue := range p { if configValue == nil { continue @@ -60,12 +64,12 @@ func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *client return providerConfig.ToClient(tel, clientConfig) } - return nil, provider.ErrProviderNotFound + return nil, ErrProviderNotFound } // validate ensure there is only one provider configured and it's supported by Glide -func (p DynLangProvider) validate() error { - configuredProviders := make([]provider.ProviderID, 0, len(p)) +func (p Config) validate() error { + configuredProviders := make([]ProviderID, 0, len(p)) for providerID, config := range p { if config != nil { @@ -115,8 +119,8 @@ func (p DynLangProvider) validate() error { return providerConfig.UnmarshalYAML(providerConfigUnmarshaller) } -func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain DynLangProvider // to avoid recursion +func (p *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain Config // to avoid recursion temp := plain{} @@ -124,7 +128,7 @@ func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error return err } - *p = DynLangProvider(temp) + *p = Config(temp) return p.validate() } diff --git a/pkg/providers/config_test.go b/pkg/providers/config_test.go index 7bb7c40..7e1d18c 100644 --- a/pkg/providers/config_test.go +++ b/pkg/providers/config_test.go @@ -5,19 +5,18 @@ import ( "path/filepath" "testing" - testprovider "github.com/EinStack/glide/pkg/providers/testing" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) func TestDynLangProvider(t *testing.T) { - LangRegistry.Register(testprovider.ProviderTest, &testprovider.Config{}) + LangRegistry.Register(ProviderTest, &TestConfig{}) type ProviderConfig struct { - Provider *DynLangProvider `yaml:"provider"` + Provider *Config `yaml:"provider"` } - prConfig := make(DynLangProvider) + prConfig := make(Config) providerConfig := ProviderConfig{ Provider: &prConfig, } diff --git a/pkg/provider/provider.go b/pkg/providers/interface.go similarity index 97% rename from pkg/provider/provider.go rename to pkg/providers/interface.go index d2b7641..0b9fe45 100644 --- a/pkg/provider/provider.go +++ b/pkg/providers/interface.go @@ -1,4 +1,4 @@ -package provider +package providers import ( "context" @@ -21,9 +21,7 @@ type ModelProvider interface { // LangProvider defines an interface a provider should fulfill to be able to serve language chat requests type LangProvider interface { ModelProvider - SupportChatStream() bool - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) } @@ -31,8 +29,6 @@ type LangProvider interface { // EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings type EmbeddingProvider interface { ModelProvider - SupportEmbedding() bool - Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) } diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go index 4dcf9ff..fee9a58 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/providers/openai/config.go @@ -3,7 +3,7 @@ package openai import ( "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/telemetry" ) @@ -52,7 +52,7 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } -var _ provider.ProviderConfig = (*Config)(nil) +var _ providers.Configurer = (*Config)(nil) // DefaultConfig for OpenAI models func DefaultConfig() *Config { @@ -66,7 +66,7 @@ func DefaultConfig() *Config { } } -func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (providers.LangProvider, error) { return NewClient(c, clientConfig, tel) } diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go index 3f626f6..8298ebf 100644 --- a/pkg/providers/registry.go +++ b/pkg/providers/registry.go @@ -2,23 +2,21 @@ package providers import ( "fmt" - - "github.com/EinStack/glide/pkg/provider" ) var LangRegistry = NewProviderRegistry() type ProviderRegistry struct { - providers map[provider.ProviderID]provider.ProviderConfig + providers map[ProviderID]Configurer } func NewProviderRegistry() *ProviderRegistry { return &ProviderRegistry{ - providers: make(map[provider.ProviderID]provider.ProviderConfig), + providers: make(map[ProviderID]Configurer), } } -func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.ProviderConfig) { +func (r *ProviderRegistry) Register(name ProviderID, config Configurer) { if _, ok := r.Get(name); ok { panic(fmt.Sprintf("provider %s is already registered", name)) } @@ -26,14 +24,14 @@ func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.Pr r.providers[name] = config } -func (r *ProviderRegistry) Get(name provider.ProviderID) (provider.ProviderConfig, bool) { +func (r *ProviderRegistry) Get(name ProviderID) (Configurer, bool) { config, ok := r.providers[name] return config, ok } -func (r *ProviderRegistry) Available() []provider.ProviderID { - available := make([]provider.ProviderID, 0, len(r.providers)) +func (r *ProviderRegistry) Available() []ProviderID { + available := make([]ProviderID, 0, len(r.providers)) for providerID := range r.providers { available = append(available, providerID) diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing.go similarity index 74% rename from pkg/providers/testing/lang.go rename to pkg/providers/testing.go index 3c27792..f4cda83 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/providers/testing.go @@ -1,14 +1,36 @@ -package testing +package providers import ( "context" "io" - clients2 "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/telemetry" +) + +const ( + ProviderTest = "testprovider" ) +type TestConfig struct { + BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` + ModelName string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` +} + +func (c *TestConfig) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (LangProvider, error) { + return NewProviderMock(nil, []RespMock{}), nil +} + +func (c *TestConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain TestConfig // to avoid recursion + + return unmarshal((*plain)(c)) +} + // RespMock mocks a chat response or a streaming chat chunk type RespMock struct { Msg string @@ -124,7 +146,7 @@ func (c *ProviderMock) SupportChatStream() bool { func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { if c.chatResps == nil { - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } responses := *c.chatResps @@ -139,9 +161,9 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas. return response.Resp(), nil } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { +func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { - return nil, clients2.ErrProviderUnavailable + return nil, clients.ErrProviderUnavailable } streams := *c.chatStreams diff --git a/pkg/providers/testing/config.go b/pkg/providers/testing/config.go deleted file mode 100644 index dd7d085..0000000 --- a/pkg/providers/testing/config.go +++ /dev/null @@ -1,29 +0,0 @@ -package testing - -import ( - "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/telemetry" -) - -const ( - ProviderTest = "testprovider" -) - -type Config struct { - BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` - ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` - ModelName string `yaml:"model" json:"model" validate:"required"` - APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` -} - -func (c *Config) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (provider.LangProvider, error) { - return NewProviderMock(nil, []RespMock{}), nil -} - -func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain Config // to avoid recursion - - return unmarshal((*plain)(c)) -} diff --git a/pkg/routers/lang/config.go b/pkg/routers/lang/config.go index a8f3f6b..8817ab5 100644 --- a/pkg/routers/lang/config.go +++ b/pkg/routers/lang/config.go @@ -4,11 +4,12 @@ import ( "fmt" "time" + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/providers" "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -17,7 +18,7 @@ import ( ) type ( - ModelConfig = models.Config[*providers.DynLangProvider] + ModelConfig = extmodel.Config[*providers.Config] ModelPoolConfig = []ModelConfig ) @@ -50,12 +51,12 @@ func NewRouterConfig(RouterID string, opt ...RouterConfigOption) *RouterConfig { } // BuildModels creates LanguageModel slice out of the given config -func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*models.LanguageModel, []*models.LanguageModel, error) { //nolint: cyclop +func (c *RouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) - chatModels := make([]*models.LanguageModel, 0, len(c.Models)) - chatStreamModels := make([]*models.LanguageModel, 0, len(c.Models)) + chatModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) + chatStreamModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) for _, modelConfig := range c.Models { if _, ok := seenIDs[modelConfig.ID]; ok { @@ -159,11 +160,11 @@ func (c *RouterConfig) BuildRetry() *retry.ExpRetry { } func (c *RouterConfig) BuildRouting( - chatModels []*models.LanguageModel, - chatStreamModels []*models.LanguageModel, + chatModels []*extmodel.LanguageModel, + chatStreamModels []*extmodel.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { - chatModelPool := make([]models.Model, 0, len(chatModels)) - chatStreamModelPool := make([]models.Model, 0, len(chatStreamModels)) + chatModelPool := make([]extmodel.Interface, 0, len(chatModels)) + chatStreamModelPool := make([]extmodel.Interface, 0, len(chatStreamModels)) for _, model := range chatModels { chatModelPool = append(chatModelPool, model) @@ -181,8 +182,8 @@ func (c *RouterConfig) BuildRouting( case routing.WeightedRoundRobin: return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil case routing.LeastLatency: - return routing.NewLeastLatencyRouting(models.ChatLatency, chatModelPool), - routing.NewLeastLatencyRouting(models.ChatStreamLatency, chatStreamModelPool), + return routing.NewLeastLatencyRouting(extmodel.ChatLatency, chatModelPool), + routing.NewLeastLatencyRouting(extmodel.ChatStreamLatency, chatStreamModelPool), nil } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index 1ed4336..1b5975b 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -27,7 +27,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -45,7 +45,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -80,7 +80,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, @@ -93,7 +93,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ cohere.ProviderID: &cohere.Config{ APIKey: "ABC", DefaultParams: &cohereParams, @@ -129,7 +129,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -147,7 +147,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -170,7 +170,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, @@ -183,7 +183,7 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: &providers.DynLangProvider{ + Provider: &providers.Config{ openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, diff --git a/pkg/routers/lang/router.go b/pkg/routers/lang/router.go index caf98a4..9dcc0c3 100644 --- a/pkg/routers/lang/router.go +++ b/pkg/routers/lang/router.go @@ -4,8 +4,9 @@ import ( "context" "errors" + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/models" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/routers/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -19,8 +20,8 @@ type RouterID = string type Router struct { routerID RouterID Config *RouterConfig - chatModels []*models.LanguageModel - chatStreamModels []*models.LanguageModel + chatModels []*extmodel.LanguageModel + chatStreamModels []*extmodel.LanguageModel chatRouting routing.LangModelRouting chatStreamRouting routing.LangModelRouting retry *retry.ExpRetry @@ -76,7 +77,7 @@ func (r *Router) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.C break } - langModel := model.(models.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) @@ -146,7 +147,7 @@ func (r *Router) ChatStream( break } - langModel := model.(models.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) modelRespC, err := langModel.ChatStream(ctx, chatParams) diff --git a/pkg/routers/lang/router_test.go b/pkg/routers/lang/router_test.go index 65c9b2f..ce6c28d 100644 --- a/pkg/routers/lang/router_test.go +++ b/pkg/routers/lang/router_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" + "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/resiliency/health" @@ -24,24 +24,24 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -70,31 +70,31 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "third", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -127,24 +127,24 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { budget := health.NewErrorBudget(1, health.MILLI) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -170,24 +170,24 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { budget := health.NewErrorBudget(1, health.MIN) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + providers.NewProviderMock(nil, []providers.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -215,24 +215,24 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + providers.NewProviderMock(nil, []providers.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), budget, *latConfig, 1, ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -257,11 +257,11 @@ func TestLangRouter_ChatStream(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Msg: "Bill"}, {Msg: "Gates"}, {Msg: "entered"}, @@ -273,10 +273,10 @@ func TestLangRouter_ChatStream(t *testing.T) { *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Msg: "Knock"}, {Msg: "Knock"}, {Msg: "joke"}, @@ -288,7 +288,7 @@ func TestLangRouter_ChatStream(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -332,19 +332,19 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, nil), + providers.NewStreamProviderMock(nil, nil), budget, *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock( - &[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock( + &[]providers.RespMock{ {Msg: "Knock"}, {Msg: "knock"}, {Msg: "joke"}, @@ -357,7 +357,7 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } @@ -401,11 +401,11 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*models.LanguageModel{ - models.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -413,10 +413,10 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { *latConfig, 1, ), - models.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + providers.NewStreamProviderMock(nil, []providers.RespStreamMock{ + providers.NewRespStreamMock(&[]providers.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -426,7 +426,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), } - modelPool := make([]models.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { modelPool = append(modelPool, model) } diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go index e6c56a6..d34f45e 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/routers/routing/least_latency.go @@ -5,7 +5,7 @@ import ( "sync/atomic" "time" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/EinStack/glide/pkg/routers/latency" ) @@ -15,16 +15,16 @@ const ( ) // LatencyGetter defines where to find latency for the specific model action -type LatencyGetter = func(model models.Model) *latency.MovingAverage +type LatencyGetter = func(model extmodel.Interface) *latency.MovingAverage // ModelSchedule defines latency update schedule for models type ModelSchedule struct { mu sync.RWMutex - model models.Model + model extmodel.Interface expireAt time.Time } -func NewSchedule(model models.Model) *ModelSchedule { +func NewSchedule(model extmodel.Interface) *ModelSchedule { schedule := &ModelSchedule{ model: model, } @@ -67,7 +67,7 @@ type LeastLatencyRouting struct { schedules []*ModelSchedule } -func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []models.Model) *LeastLatencyRouting { +func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []extmodel.Interface) *LeastLatencyRouting { schedules := make([]*ModelSchedule, 0, len(models)) for _, model := range models { @@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator { // other model latencies that might have improved over time). // For that, we introduced expiration time after which the model receives a request // even if it was not the fastest to respond -func (r *LeastLatencyRouting) Next() (models.Model, error) { //nolint:cyclop +func (r *LeastLatencyRouting) Next() (extmodel.Interface, error) { //nolint:cyclop coldSchedules := r.getColdModelSchedules() if len(coldSchedules) > 0 { diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go index 2f18b32..0e6618f 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/routers/routing/least_latency_test.go @@ -5,9 +5,7 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -33,13 +31,13 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) } - routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, modelPool) + routing := NewLeastLatencyRouting(extmodel.ChatMockLatency, modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -107,7 +105,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { for _, model := range tc.models { schedules = append(schedules, &ModelSchedule{ - model: ptesting.NewLangModelMock( + model: extmodel.NewLangModelMock( model.modelID, model.healthy, model.latency, @@ -118,7 +116,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { } routing := LeastLatencyRouting{ - latencyGetter: ptesting.ChatMockLatency, + latencyGetter: extmodel.ChatMockLatency, schedules: schedules, } @@ -144,13 +142,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { for name, latencies := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(latencies)) + modelPool := make([]extmodel.Interface, 0, len(latencies)) for idx, latency := range latencies { - modelPool = append(modelPool, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(models.ChatLatency, modelPool) + routing := NewLeastLatencyRouting(extmodel.ChatLatency, modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/priority.go b/pkg/routers/routing/priority.go index 04d4d94..7cf5cee 100644 --- a/pkg/routers/routing/priority.go +++ b/pkg/routers/routing/priority.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -15,10 +15,10 @@ const ( // Priority of models are defined as position of the model on the list // (e.g. the first model definition has the highest priority, then the second model definition and so on) type PriorityRouting struct { - models []models.Model + models []extmodel.Interface } -func NewPriority(models []models.Model) *PriorityRouting { +func NewPriority(models []extmodel.Interface) *PriorityRouting { return &PriorityRouting{ models: models, } @@ -35,10 +35,10 @@ func (r *PriorityRouting) Iterator() LangModelIterator { type PriorityIterator struct { idx *atomic.Uint64 - models []models.Model + models []extmodel.Interface } -func (r PriorityIterator) Next() (models.Model, error) { +func (r PriorityIterator) Next() (extmodel.Interface, error) { modelPool := r.models for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) { diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go index c0713f3..98e27e7 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/routers/routing/priority_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -29,10 +27,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } routing := NewPriority(modelPool) @@ -49,10 +47,10 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { } func TestPriorityRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } routing := NewPriority(modelPool) diff --git a/pkg/routers/routing/round_robin.go b/pkg/routers/routing/round_robin.go index abd2ff9..7582cbc 100644 --- a/pkg/routers/routing/round_robin.go +++ b/pkg/routers/routing/round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -13,10 +13,10 @@ const ( // RoundRobinRouting routes request to the next model in the list in cycle type RoundRobinRouting struct { idx atomic.Uint64 - models []models.Model + models []extmodel.Interface } -func NewRoundRobinRouting(models []models.Model) *RoundRobinRouting { +func NewRoundRobinRouting(models []extmodel.Interface) *RoundRobinRouting { return &RoundRobinRouting{ models: models, } @@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *RoundRobinRouting) Next() (models.Model, error) { +func (r *RoundRobinRouting) Next() (extmodel.Interface, error) { modelLen := len(r.models) // in order to avoid infinite loop in case of no healthy model is available, diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go index c2c6d30..7287f46 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/routers/routing/round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -30,10 +28,10 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } routing := NewRoundRobinRouting(modelPool) @@ -52,10 +50,10 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { } func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } routing := NewRoundRobinRouting(modelPool) diff --git a/pkg/routers/routing/strategies.go b/pkg/routers/routing/strategies.go index 960702a..48d18ab 100644 --- a/pkg/routers/routing/strategies.go +++ b/pkg/routers/routing/strategies.go @@ -3,7 +3,7 @@ package routing import ( "errors" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) var ErrNoHealthyModels = errors.New("no healthy models found") @@ -16,5 +16,5 @@ type LangModelRouting interface { } type LangModelIterator interface { - Next() (models.Model, error) + Next() (extmodel.Interface, error) } diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/routers/routing/weighted_round_robin.go index dfbee41..418add9 100644 --- a/pkg/routers/routing/weighted_round_robin.go +++ b/pkg/routers/routing/weighted_round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync" - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -11,7 +11,7 @@ const ( ) type Weighter struct { - model models.Model + model extmodel.Interface currentWeight int } @@ -36,7 +36,7 @@ type WRoundRobinRouting struct { weights []*Weighter } -func NewWeightedRoundRobin(models []models.Model) *WRoundRobinRouting { +func NewWeightedRoundRobin(models []extmodel.Interface) *WRoundRobinRouting { weights := make([]*Weighter, 0, len(models)) for _, model := range models { @@ -55,7 +55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *WRoundRobinRouting) Next() (models.Model, error) { +func (r *WRoundRobinRouting) Next() (extmodel.Interface, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go index e24d4e8..7ec9b24 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/models" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -116,10 +114,10 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - modelPool := make([]models.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - modelPool = append(modelPool, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) } routing := NewWeightedRoundRobin(modelPool) @@ -142,10 +140,10 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { } func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { - modelPool := []models.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 2), - ptesting.NewLangModelMock("third", false, 0, 3), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 2), + extmodel.NewLangModelMock("third", false, 0, 3), } routing := NewWeightedRoundRobin(modelPool)