forked from mudler/LocalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Dave Lee <[email protected]>
- Loading branch information
1 parent
6b07ded
commit 1c1ee61
Showing
9 changed files
with
373 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
[ | ||
{ | ||
"name": "OnlyTwoLT2", | ||
"desc": "Test Rule: Only allow 2 backends to be loaded at once, unload the least recently used if a 3rd request comes in", | ||
"salience": 10, | ||
"when": [ | ||
"Result.Action == ActionDefs.Blank", | ||
"LoadedModelCount <= 1" | ||
], | ||
"then": [ | ||
"Result.Action = ActionDefs.Continue", | ||
"Result.ModelName = RequestedModelName" | ||
] | ||
}, | ||
{ | ||
"name": "OnlyTwoGT2", | ||
"desc": "Test Rule: Only allow 2 backends to be loaded at once, unload the least recently used if a 3rd request comes in", | ||
"salience": 10, | ||
"when": [ | ||
"Result.Action == ActionDefs.Blank", | ||
"LoadedModelCount >= 2" | ||
], | ||
"then": [ | ||
"Log(\"Too many backends in use, unloading least recently used...\")", | ||
"LRU = LoadedModels[0].ModelName", | ||
"ModelLoader.ShutdownModel(LRU)", | ||
"Result.Action = ActionDefs.Continue", | ||
"Result.ModelName = RequestedModelName" | ||
] | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package services | ||
|
||
import ( | ||
"os" | ||
"path" | ||
|
||
"github.com/go-skynet/LocalAI/core/config" | ||
"github.com/go-skynet/LocalAI/pkg/model" | ||
"github.com/hyperjumptech/grule-rule-engine/ast" | ||
"github.com/hyperjumptech/grule-rule-engine/builder" | ||
"github.com/hyperjumptech/grule-rule-engine/engine" | ||
"github.com/hyperjumptech/grule-rule-engine/pkg" | ||
) | ||
|
||
const ruleBasedBackendServiceKLName = "RuleBasedBackendService" | ||
const ruleBasedBackendServiceKLVersion = "0.0.1" | ||
|
||
type RuleBasedBackendResult struct { | ||
Action string | ||
ModelName string | ||
// TODO other? | ||
} | ||
|
||
type ruleBasedBackendResultActionDefinitionsStruct struct { | ||
Blank string | ||
Continue string | ||
Error string | ||
Enqueue string | ||
} | ||
|
||
var ruleBasedBackendResultActionDefinitions := ruleBasedBackendResultActionDefinitionsStruct{ | ||
Blank: "", | ||
Continue: "continue", | ||
Error: "error", | ||
Enqueue: "enqueue", | ||
} | ||
|
||
type RuleBasedBackendService struct { | ||
configLoader *config.BackendConfigLoader | ||
modelLoader *model.ModelLoader | ||
appConfig *config.ApplicationConfig | ||
knowledgeLibrary *ast.KnowledgeLibrary | ||
} | ||
|
||
func NewRuleBasedBackendService(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) RuleBasedBackendService { | ||
rbbs := RuleBasedBackendService{ | ||
configLoader: configLoader, | ||
modelLoader: modelLoader, | ||
appConfig: appConfig, | ||
} | ||
|
||
// TODO: Phase 2 is to have bundled rule sets for common scenarios, such as always allow, SINGLE_BACKEND, only allowing authorized requests to load new backends, etc | ||
// For now, no settings for that, always use a custom json file for testing. | ||
res, err := rbbs.getExternalRuleFileResource() | ||
if err != nil { | ||
rbbs.ReloadRules(res) | ||
} | ||
|
||
return rbbs | ||
} | ||
|
||
func (rbbs RuleBasedBackendService) getExternalRuleFileResource() (pkg.Resource, error) { | ||
ruleFilePath := path.Join(rbbs.appConfig.ConfigsDir, "backend_rules.json") | ||
f, err := os.Open(ruleFilePath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
underlying := pkg.NewReaderResource(f) | ||
resource, err := pkg.NewJSONResourceFromResource(underlying) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return resource, nil | ||
} | ||
|
||
func (rbbs RuleBasedBackendService) ReloadRules(res pkg.Resource) error { | ||
rbbs.knowledgeLibrary = ast.NewKnowledgeLibrary() | ||
ruleBuilder := builder.NewRuleBuilder(rbbs.knowledgeLibrary) | ||
return ruleBuilder.BuildRuleFromResource(ruleBasedBackendServiceKLName, ruleBasedBackendServiceKLVersion, res) | ||
} | ||
|
||
func (rbbs RuleBasedBackendService) RuleBasedLoad(modelName string, alreadyLoadedResult *RuleBasedBackendResult, source string, optionalRequest interface{}) (*RuleBasedBackendResult, error) { | ||
backendId, bc, err := getModelLoaderIDFromModelName(rbbs.configLoader, modelName) | ||
if err != nil { | ||
return nil, err | ||
} | ||
lmm := rbbs.modelLoader.CheckIsLoaded(backendId, true) | ||
if lmm != nil { | ||
return alreadyLoadedResult, nil | ||
} | ||
result := RuleBasedBackendResult{} | ||
ruleBasedLoadDataCtx := ast.NewDataContext() | ||
|
||
ruleBasedLoadDataCtx.Add("ModelLoader", rbbs.modelLoader) | ||
ruleBasedLoadDataCtx.Add("LoadedModelCount", rbbs.modelLoader.LoadedModelCount()) | ||
ruleBasedLoadDataCtx.Add("LoadedModels", rbbs.modelLoader.SortedLoadedModelMetadata()) | ||
|
||
ruleBasedLoadDataCtx.Add("ActionDefs", ruleBasedBackendResultActionDefinitions) | ||
|
||
ruleBasedLoadDataCtx.Add("RequestedModelName", modelName) | ||
ruleBasedLoadDataCtx.Add("Source", source) | ||
ruleBasedLoadDataCtx.Add("Request", optionalRequest) | ||
|
||
ruleBasedLoadDataCtx.Add("BackendConfig", bc) | ||
ruleBasedLoadDataCtx.Add("AppConfig", rbbs.appConfig) | ||
|
||
ruleBasedLoadDataCtx.Add("Result", result) | ||
|
||
knowledgeBase, err := rbbs.knowledgeLibrary.NewKnowledgeBaseInstance(ruleBasedBackendServiceKLName, ruleBasedBackendServiceKLVersion) | ||
if err != nil { | ||
return nil, err | ||
} | ||
engine := engine.NewGruleEngine() | ||
err = engine.Execute(ruleBasedLoadDataCtx, knowledgeBase) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &result, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.