Skip to content

Commit

Permalink
MASS RENAME: o / options => appConfig where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
dave-gray101 committed Feb 26, 2024
1 parent bc5609c commit c25b7d9
Show file tree
Hide file tree
Showing 15 changed files with 146 additions and 147 deletions.
26 changes: 13 additions & 13 deletions core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)

func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (func() ([]float32, error), error) {
if !c.Embeddings {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
if !backendConfig.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}

modelFile := c.Model
modelFile := backendConfig.Model

grpcOpts := gRPCModelOpts(c)
grpcOpts := gRPCModelOpts(backendConfig)

var inferenceModel interface{}
var err error

opts := modelOpts(c, o, []model.Option{
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithLoadGRPCLoadModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithContext(appConfig.Context),
})

if c.Backend == "" {
if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
} else {
opts = append(opts, model.WithBackendString(c.Backend))
opts = append(opts, model.WithBackendString(backendConfig.Backend))
inferenceModel, err = loader.BackendLoader(opts...)
}
if err != nil {
Expand All @@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
switch model := inferenceModel.(type) {
case grpc.Backend:
fn = func() ([]float32, error) {
predictOptions := gRPCPredictOpts(c, loader.ModelPath)
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
if len(tokens) > 0 {
embeds := []int32{}

Expand All @@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.EmbeddingTokens = embeds

res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}
Expand All @@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}
predictOptions.Embeddings = s

res, err := model.Embeddings(o.Context, predictOptions)
res, err := model.Embeddings(appConfig.Context, predictOptions)
if err != nil {
return nil, err
}
Expand Down
44 changes: 22 additions & 22 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (func() error, error) {
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {

opts := modelOpts(c, o, []model.Option{
model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModel(c.Model),
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithContext(appConfig.Context),
model.WithModel(backendConfig.Model),
model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
PipelineType: c.Diffusers.PipelineType,
CFGScale: c.Diffusers.CFGScale,
LoraAdapter: c.LoraAdapter,
LoraScale: c.LoraScale,
LoraBase: c.LoraBase,
IMG2IMG: c.Diffusers.IMG2IMG,
CLIPModel: c.Diffusers.ClipModel,
CLIPSubfolder: c.Diffusers.ClipSubFolder,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
CUDA: backendConfig.CUDA || backendConfig.Diffusers.CUDA,
SchedulerType: backendConfig.Diffusers.SchedulerType,
PipelineType: backendConfig.Diffusers.PipelineType,
CFGScale: backendConfig.Diffusers.CFGScale,
LoraAdapter: backendConfig.LoraAdapter,
LoraScale: backendConfig.LoraScale,
LoraBase: backendConfig.LoraBase,
IMG2IMG: backendConfig.Diffusers.IMG2IMG,
CLIPModel: backendConfig.Diffusers.ClipModel,
CLIPSubfolder: backendConfig.Diffusers.ClipSubFolder,
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
ControlNet: backendConfig.Diffusers.ControlNet,
}),
})

Expand All @@ -40,19 +40,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat

fn := func() error {
_, err := inferenceModel.GenerateImage(
o.Context,
appConfig.Context,
&proto.GenerateImageRequest{
Height: int32(height),
Width: int32(width),
Mode: int32(mode),
Step: int32(step),
Seed: int32(seed),
CLIPSkip: int32(c.Diffusers.ClipSkip),
CLIPSkip: int32(backendConfig.Diffusers.ClipSkip),
PositivePrompt: positive_prompt,
NegativePrompt: negative_prompt,
Dst: dst,
Src: src,
EnableParameters: c.Diffusers.EnableParameters,
EnableParameters: backendConfig.Diffusers.EnableParameters,
})
return err
}
Expand Down
14 changes: 7 additions & 7 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
model "github.com/go-skynet/LocalAI/pkg/model"
)

func ModelTranscription(audio, language string, ml *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (*schema.Result, error) {
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) {

opts := modelOpts(c, o, []model.Option{
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModel(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
model.WithModel(backendConfig.Model),
model.WithContext(appConfig.Context),
model.WithThreads(uint32(backendConfig.Threads)),
model.WithAssetDir(appConfig.AssetsDestination),
})

whisperModel, err := ml.BackendLoader(opts...)
Expand All @@ -33,6 +33,6 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, c config.
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
Threads: uint32(backendConfig.Threads),
})
}
18 changes: 9 additions & 9 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ func generateUniqueFileName(dir, baseName, ext string) string {
}
}

func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *config.ApplicationConfig, c config.BackendConfig) (string, *proto.Result, error) {
func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}

grpcOpts := gRPCModelOpts(c)
grpcOpts := gRPCModelOpts(backendConfig)

opts := modelOpts(config.BackendConfig{}, o, []model.Option{
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
model.WithContext(appConfig.Context),
model.WithAssetDir(appConfig.AssetsDestination),
model.WithLoadGRPCLoadModelOpts(grpcOpts),
})
piperModel, err := loader.BackendLoader(opts...)
Expand All @@ -53,19 +53,19 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *con
return "", nil, fmt.Errorf("could not load piper model")
}

if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)

// If the model file is not empty, we pass it joined with the model path
modelPath := ""
if modelFile != "" {
if bb != model.TransformersMusicGen {
modelPath = filepath.Join(loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil {
if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil {
return "", nil, err
}
} else {
Expand Down
1 change: 0 additions & 1 deletion core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ type ApplicationConfig struct {
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
// Metrics *metrics.Metrics

ModelLibraryURL string

Expand Down
Loading

0 comments on commit c25b7d9

Please sign in to comment.