Skip to content

Commit

Permalink
#67: Got rid of provider package & did other layout restructuring to …
Browse files Browse the repository at this point in the history
…fix circular dependency issues
  • Loading branch information
roma-glushko committed Aug 8, 2024
1 parent 339e3f9 commit b662b1e
Show file tree
Hide file tree
Showing 26 changed files with 217 additions and 245 deletions.
11 changes: 6 additions & 5 deletions pkg/models/config.go → pkg/extmodel/config.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package models
package extmodel

import (
"fmt"

"github.com/EinStack/glide/pkg/providers"

"github.com/EinStack/glide/pkg/clients"
"github.com/EinStack/glide/pkg/provider"
"github.com/EinStack/glide/pkg/resiliency/health"
"github.com/EinStack/glide/pkg/routers/latency"
"github.com/EinStack/glide/pkg/telemetry"
)

// Config defines an extra configuration for a model wrapper around a provider
type Config[P provider.ProviderConfig] struct {
type Config[P providers.Configurer] struct {
ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router)
Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled?
ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"`
Expand All @@ -22,15 +23,15 @@ type Config[P provider.ProviderConfig] struct {
Provider P `yaml:"provider" json:"provider"`
}

func NewConfig[P provider.ProviderConfig](ID string) *Config[P] {
func NewConfig[P providers.Configurer](ID string) *Config[P] {
config := DefaultConfig[P]()

config.ID = ID

return &config
}

func DefaultConfig[P provider.ProviderConfig]() Config[P] {
func DefaultConfig[P providers.Configurer]() Config[P] {
return Config[P]{
Enabled: true,
Client: clients.DefaultClientConfig(),
Expand Down
14 changes: 7 additions & 7 deletions pkg/models/lang.go → pkg/extmodel/lang.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package models
package extmodel

import (
"context"
"io"
"time"

"github.com/EinStack/glide/pkg/provider"
"github.com/EinStack/glide/pkg/providers"

"github.com/EinStack/glide/pkg/clients"
health2 "github.com/EinStack/glide/pkg/resiliency/health"
Expand All @@ -18,7 +18,7 @@ import (
)

type LangModel interface {
Model
Interface
Provider() string
ModelName() string
Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)
Expand All @@ -32,14 +32,14 @@ type LangModel interface {
type LanguageModel struct {
modelID string
weight int
client provider.LangProvider
client providers.LangProvider
healthTracker *health2.Tracker
chatLatency *latency.MovingAverage
chatStreamLatency *latency.MovingAverage
latencyUpdateInterval *fields.Duration
}

func NewLangModel(modelID string, client provider.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel {
func NewLangModel(modelID string, client providers.LangProvider, budget *health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel {
return &LanguageModel{
modelID: modelID,
client: client,
Expand Down Expand Up @@ -170,10 +170,10 @@ func (m *LanguageModel) ModelName() string {
return m.client.ModelName()
}

func ChatLatency(model Model) *latency.MovingAverage {
func ChatLatency(model Interface) *latency.MovingAverage {
return model.(LanguageModel).ChatLatency()
}

func ChatStreamLatency(model Model) *latency.MovingAverage {
func ChatStreamLatency(model Interface) *latency.MovingAverage {
return model.(LanguageModel).ChatStreamLatency()
}
11 changes: 11 additions & 0 deletions pkg/extmodel/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package extmodel

import "github.com/EinStack/glide/pkg/config/fields"

// Interface represent a configured external modality-agnostic model with its routing properties and status
type Interface interface {
ID() string
Healthy() bool
LatencyUpdateInterval() *fields.Duration
Weight() int
}
5 changes: 2 additions & 3 deletions pkg/providers/testing/models.go → pkg/extmodel/testing.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package testing
package extmodel

import (
"time"

"github.com/EinStack/glide/pkg/config/fields"
"github.com/EinStack/glide/pkg/models"
"github.com/EinStack/glide/pkg/routers/latency"
)

Expand Down Expand Up @@ -53,6 +52,6 @@ func (m LangModelMock) Weight() int {
return m.weight
}

func ChatMockLatency(model models.Model) *latency.MovingAverage {
func ChatMockLatency(model Interface) *latency.MovingAverage {
return model.(LangModelMock).chatLatency
}
11 changes: 0 additions & 11 deletions pkg/models/model.go

This file was deleted.

12 changes: 0 additions & 12 deletions pkg/provider/config.go

This file was deleted.

26 changes: 15 additions & 11 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"strings"

"github.com/EinStack/glide/pkg/provider"
"github.com/go-playground/validator/v10"

"gopkg.in/yaml.v3"
Expand All @@ -22,12 +21,17 @@ func init() {
validate = validator.New()
}

// TODO: rename DynLangProvider to DynLangProviderConfig
type DynLangProvider map[provider.ProviderID]interface{}
// TODO: Configurer should be more generic, not tied to LangProviders
type Configurer interface {
UnmarshalYAML(unmarshal func(interface{}) error) error
ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error)
}

type Config map[ProviderID]interface{}

var _ provider.ProviderConfig = (*DynLangProvider)(nil)
var _ Configurer = (*Config)(nil)

func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) {
func (p Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) {
for providerID, configValue := range p {
if configValue == nil {
continue
Expand Down Expand Up @@ -60,12 +64,12 @@ func (p DynLangProvider) ToClient(tel *telemetry.Telemetry, clientConfig *client
return providerConfig.ToClient(tel, clientConfig)
}

return nil, provider.ErrProviderNotFound
return nil, ErrProviderNotFound
}

// validate ensure there is only one provider configured and it's supported by Glide
func (p DynLangProvider) validate() error {
configuredProviders := make([]provider.ProviderID, 0, len(p))
func (p Config) validate() error {
configuredProviders := make([]ProviderID, 0, len(p))

for providerID, config := range p {
if config != nil {
Expand Down Expand Up @@ -115,16 +119,16 @@ func (p DynLangProvider) validate() error {
return providerConfig.UnmarshalYAML(providerConfigUnmarshaller)
}

func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain DynLangProvider // to avoid recursion
func (p *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain Config // to avoid recursion

temp := plain{}

if err := unmarshal(&temp); err != nil {
return err
}

*p = DynLangProvider(temp)
*p = Config(temp)

return p.validate()
}
7 changes: 3 additions & 4 deletions pkg/providers/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@ import (
"path/filepath"
"testing"

testprovider "github.com/EinStack/glide/pkg/providers/testing"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)

func TestDynLangProvider(t *testing.T) {
LangRegistry.Register(testprovider.ProviderTest, &testprovider.Config{})
LangRegistry.Register(ProviderTest, &TestConfig{})

type ProviderConfig struct {
Provider *DynLangProvider `yaml:"provider"`
Provider *Config `yaml:"provider"`
}

prConfig := make(DynLangProvider)
prConfig := make(Config)
providerConfig := ProviderConfig{
Provider: &prConfig,
}
Expand Down
6 changes: 1 addition & 5 deletions pkg/provider/provider.go → pkg/providers/interface.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package provider
package providers

import (
"context"
Expand All @@ -21,18 +21,14 @@ type ModelProvider interface {
// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests
type LangProvider interface {
ModelProvider

SupportChatStream() bool

Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)
ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error)
}

// EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings
type EmbeddingProvider interface {
ModelProvider

SupportEmbedding() bool

Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error)
}
6 changes: 3 additions & 3 deletions pkg/providers/openai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package openai
import (
"github.com/EinStack/glide/pkg/clients"
"github.com/EinStack/glide/pkg/config/fields"
"github.com/EinStack/glide/pkg/provider"
"github.com/EinStack/glide/pkg/providers"
"github.com/EinStack/glide/pkg/telemetry"
)

Expand Down Expand Up @@ -52,7 +52,7 @@ type Config struct {
DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"`
}

var _ provider.ProviderConfig = (*Config)(nil)
var _ providers.Configurer = (*Config)(nil)

// DefaultConfig for OpenAI models
func DefaultConfig() *Config {
Expand All @@ -66,7 +66,7 @@ func DefaultConfig() *Config {
}
}

func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) {
func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (providers.LangProvider, error) {
return NewClient(c, clientConfig, tel)
}

Expand Down
14 changes: 6 additions & 8 deletions pkg/providers/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,36 @@ package providers

import (
"fmt"

"github.com/EinStack/glide/pkg/provider"
)

var LangRegistry = NewProviderRegistry()

type ProviderRegistry struct {
providers map[provider.ProviderID]provider.ProviderConfig
providers map[ProviderID]Configurer
}

func NewProviderRegistry() *ProviderRegistry {
return &ProviderRegistry{
providers: make(map[provider.ProviderID]provider.ProviderConfig),
providers: make(map[ProviderID]Configurer),
}
}

func (r *ProviderRegistry) Register(name provider.ProviderID, config provider.ProviderConfig) {
func (r *ProviderRegistry) Register(name ProviderID, config Configurer) {
if _, ok := r.Get(name); ok {
panic(fmt.Sprintf("provider %s is already registered", name))
}

r.providers[name] = config
}

func (r *ProviderRegistry) Get(name provider.ProviderID) (provider.ProviderConfig, bool) {
func (r *ProviderRegistry) Get(name ProviderID) (Configurer, bool) {
config, ok := r.providers[name]

return config, ok
}

func (r *ProviderRegistry) Available() []provider.ProviderID {
available := make([]provider.ProviderID, 0, len(r.providers))
func (r *ProviderRegistry) Available() []ProviderID {
available := make([]ProviderID, 0, len(r.providers))

for providerID := range r.providers {
available = append(available, providerID)
Expand Down
34 changes: 28 additions & 6 deletions pkg/providers/testing/lang.go → pkg/providers/testing.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
package testing
package providers

import (
"context"
"io"

clients2 "github.com/EinStack/glide/pkg/clients"

"github.com/EinStack/glide/pkg/api/schemas"
"github.com/EinStack/glide/pkg/clients"
"github.com/EinStack/glide/pkg/config/fields"
"github.com/EinStack/glide/pkg/telemetry"
)

const (
ProviderTest = "testprovider"
)

type TestConfig struct {
BaseURL string `yaml:"base_url" json:"base_url" validate:"required"`
ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"`
ModelName string `yaml:"model" json:"model" validate:"required"`
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
}

func (c *TestConfig) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (LangProvider, error) {
return NewProviderMock(nil, []RespMock{}), nil
}

func (c *TestConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TestConfig // to avoid recursion

return unmarshal((*plain)(c))
}

// RespMock mocks a chat response or a streaming chat chunk
type RespMock struct {
Msg string
Expand Down Expand Up @@ -124,7 +146,7 @@ func (c *ProviderMock) SupportChatStream() bool {

func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) {
if c.chatResps == nil {
return nil, clients2.ErrProviderUnavailable
return nil, clients.ErrProviderUnavailable
}

responses := *c.chatResps
Expand All @@ -139,9 +161,9 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.
return response.Resp(), nil
}

func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) {
func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
if c.chatStreams == nil || c.idx >= len(*c.chatStreams) {
return nil, clients2.ErrProviderUnavailable
return nil, clients.ErrProviderUnavailable
}

streams := *c.chatStreams
Expand Down
Loading

0 comments on commit b662b1e

Please sign in to comment.