Skip to content

Commit

Permalink
experimental service for custom loading rules
Browse files Browse the repository at this point in the history
  • Loading branch information
dave-gray101 committed Apr 2, 2024
1 parent fbdf3f1 commit 15e6f93
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 35 deletions.
40 changes: 21 additions & 19 deletions core/services/backend_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,10 @@ import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)

type BackendMonitor struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}

func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: appConfig,
}
}

func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) {
config, exists := bm.configLoader.GetBackendConfig(modelName)
// This utility extension is used for backend_monitor and backend_rules, but nowhere outside of service.
// trying out this style - TODO is it better or worse
func getModelLoaderIDFromModelName(bcl *config.BackendConfigLoader, modelName string) (string, config.BackendConfig, error) {
config, exists := bcl.GetBackendConfig(modelName)
var backendId string
if exists {
backendId = config.Model
Expand All @@ -43,7 +31,21 @@ func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string
backendId = fmt.Sprintf("%s.bin", backendId)
}

return backendId, nil
return backendId, config, nil
}

type BackendMonitor struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name.
}

func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor {
return BackendMonitor{
configLoader: configLoader,
modelLoader: modelLoader,
options: appConfig,
}
}

func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.BackendMonitorResponse, error) {
Expand Down Expand Up @@ -102,7 +104,7 @@ func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*schema.Backe
}

func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
backendId, _, err := getModelLoaderIDFromModelName(bm.configLoader, modelName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -132,7 +134,7 @@ func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse
}

func (bm BackendMonitor) ShutdownModel(modelName string) error {
backendId, err := bm.getModelLoaderIDFromModelName(modelName)
backendId, _, err := getModelLoaderIDFromModelName(bm.configLoader, modelName)
if err != nil {
return err
}
Expand Down
118 changes: 118 additions & 0 deletions core/services/backend_rules.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package services

import (
"os"
"path"

"github.com/go-skynet/LocalAI/core/config"
"github.com/go-skynet/LocalAI/pkg/model"
"github.com/hyperjumptech/grule-rule-engine/ast"
"github.com/hyperjumptech/grule-rule-engine/builder"
"github.com/hyperjumptech/grule-rule-engine/engine"
"github.com/hyperjumptech/grule-rule-engine/pkg"
)

const ruleBasedBackendServiceKLName = "RuleBasedBackendService"
const ruleBasedBackendServiceKLVersion = "0.0.1"

type RuleBasedBackendResult struct {
Action string
ModelName string
// TODO other?
}

type ruleBasedBackendResultActionDefinitionsStruct struct {
Blank string
Continue string
Error string
Enqueue string
}

var ruleBasedBackendResultActionDefinitions := ruleBasedBackendResultActionDefinitionsStruct{
Blank: "",
Continue: "continue",
Error: "error",
Enqueue: "enqueue",
}

type RuleBasedBackendService struct {
configLoader *config.BackendConfigLoader
modelLoader *model.ModelLoader
appConfig *config.ApplicationConfig
knowledgeLibrary *ast.KnowledgeLibrary
}

func NewRuleBasedBackendService(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) RuleBasedBackendService {
rbbs := RuleBasedBackendService{
configLoader: configLoader,
modelLoader: modelLoader,
appConfig: appConfig,
}

// TODO: Phase 2 is to have bundled rule sets for common scenarios, such as always allow, SINGLE_BACKEND, only allowing authorized requests to load new backends, etc
// For now, no settings for that, always use a custom json file for testing.
res, err := rbbs.getExternalRuleFileResource()
if err != nil {
rbbs.ReloadRules(res)
}

return rbbs
}

func (rbbs RuleBasedBackendService) getExternalRuleFileResource() (pkg.Resource, error) {
ruleFilePath := path.Join(rbbs.appConfig.ConfigsDir, "backend_rules.json")
f, err := os.Open(ruleFilePath)
if err != nil {
return nil, err
}
underlying := pkg.NewReaderResource(f)
resource, err := pkg.NewJSONResourceFromResource(underlying)
if err != nil {
return nil, err
}
return resource, nil
}

func (rbbs RuleBasedBackendService) ReloadRules(res pkg.Resource) error {
rbbs.knowledgeLibrary = ast.NewKnowledgeLibrary()
ruleBuilder := builder.NewRuleBuilder(rbbs.knowledgeLibrary)
return ruleBuilder.BuildRuleFromResource(ruleBasedBackendServiceKLName, ruleBasedBackendServiceKLVersion, res)
}

func (rbbs RuleBasedBackendService) RuleBasedLoad(modelName string, alreadyLoadedResult *RuleBasedBackendResult, source string, optionalRequest interface{}) (*RuleBasedBackendResult, error) {
backendId, bc, err := getModelLoaderIDFromModelName(rbbs.configLoader, modelName)
if err != nil {
return nil, err
}
lmm := rbbs.modelLoader.CheckIsLoaded(backendId, true)
if lmm != nil {
return alreadyLoadedResult, nil
}
result := RuleBasedBackendResult{}
ruleBasedLoadDataCtx := ast.NewDataContext()

ruleBasedLoadDataCtx.Add("ModelLoader", rbbs.modelLoader)
ruleBasedLoadDataCtx.Add("LoadedModelCount", rbbs.modelLoader.LoadedModelCount()) // Still relevant after second line???
ruleBasedLoadDataCtx.Add("LoadedModels", rbbs.modelLoader.SortedLoadedModelMetadata())

ruleBasedLoadDataCtx.Add("ActionDefs", ruleBasedBackendResultActionDefinitions)

ruleBasedLoadDataCtx.Add("RequestedModelName", modelName)
ruleBasedLoadDataCtx.Add("Source", source)
ruleBasedLoadDataCtx.Add("Request", optionalRequest)
ruleBasedLoadDataCtx.Add("BackendConfig", bc)

ruleBasedLoadDataCtx.Add("Result", result)

knowledgeBase, err := rbbs.knowledgeLibrary.NewKnowledgeBaseInstance(ruleBasedBackendServiceKLName, ruleBasedBackendServiceKLVersion)
if err != nil {
return nil, err
}
engine := engine.NewGruleEngine()
err = engine.Execute(ruleBasedLoadDataCtx, knowledgeBase)
if err != nil {
return nil, err
}
return &result, nil

}
25 changes: 23 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/google/uuid v1.5.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hpcloud/tail v1.0.0
github.com/hyperjumptech/grule-rule-engine v1.15.0
github.com/imdario/mergo v0.3.16
github.com/mholt/archiver/v3 v3.5.1
github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c
Expand Down Expand Up @@ -60,26 +61,36 @@ require (
)

require (
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.2.0 // indirect
github.com/Microsoft/go-winio v0.6.0 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371 // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect
github.com/aymanbagabas/go-osc52 v1.0.3 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cloudflare/circl v1.3.3 // indirect
github.com/containerd/continuity v0.3.0 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/docker/cli v20.10.17+incompatible // indirect
github.com/docker/docker v20.10.7+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
github.com/go-git/go-billy/v5 v5.5.0 // indirect
github.com/go-git/go-git/v5 v5.11.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/jsonreference v0.21.0 // indirect
Expand All @@ -88,12 +99,15 @@ require (
github.com/gofiber/template v1.8.3 // indirect
github.com/gofiber/utils v1.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/huandu/xstrings v1.3.3 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
Expand All @@ -112,17 +126,21 @@ require (
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/opencontainers/runc v1.1.5 // indirect
github.com/pierrec/lz4/v4 v4.1.2 // indirect
github.com/pjbgf/sha1cd v0.3.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/sergi/go-diff v1.1.0 // indirect
github.com/shopspring/decimal v1.2.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/skeema/knownhosts v1.2.1 // indirect
github.com/spf13/cast v1.3.1 // indirect
github.com/swaggo/files/v2 v2.0.0 // indirect
github.com/ulikunitz/xz v0.5.9 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
Expand All @@ -131,12 +149,15 @@ require (
github.com/yuin/goldmark-emoji v1.0.1 // indirect
go.opentelemetry.io/otel/sdk v1.19.0 // indirect
go.opentelemetry.io/otel/trace v1.19.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.25.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/term v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect
gopkg.in/fsnotify.v1 v1.4.7 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
)

require (
Expand Down
Loading

0 comments on commit 15e6f93

Please sign in to comment.