Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft feat: roles for api keys #2321

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ release/
.idea

# Generated during build
backend-assets/*
**/backend-assets/*
!backend-assets/.keep
prepare
/ggml-metal.metal
Expand Down
12 changes: 12 additions & 0 deletions configuration/roles.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"admin": ["*"],
"llm-user": ["POST|/chat/completions", "POST|/edits", "POST|/completions", "POST|/embeddings", "POST|/rerank", "GET|/models"],
"audio-user": ["POST|/audio/transcriptions", "POST|/audio/speech", "POST|/tts", "POST|/text-to-speech"],
"image-user": ["POST|/images/generations"],
"ui": ["GET|/", "GET|/browse", "GET|/browse/", "POST|/browse/search/models",
"GET|/browse/job/progress", "GET|/browse/job",
"GET|/chat", "GET|/chat/", "GET|/chat/:model",
"GET|/text2image", "GET|/text2image/", "GET|/text2image/:model",
"GET|/tts", "GET|/tts/", "GET|/tts/:model"],
"user": ["ui", "llm-user", "audio-user", "image-user"]
}
12 changes: 6 additions & 6 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ type RunCMD struct {
Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" default:"4" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"`
ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"`

Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`
Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"`
CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"`
CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"`
UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
APIKeys map[string][]string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"`

ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"`
Expand Down
5 changes: 3 additions & 2 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type ApplicationConfig struct {
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
ApiKeys map[string][]string // ApiKeys maps the key itself to a list of endpoints [or roles] that the key should be permitted to access
Roles map[string][]string // Roles is a simple "shortcut" mapping a name to a list of endpoints

ModelLibraryURL string

Expand Down Expand Up @@ -271,7 +272,7 @@ func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption {
}
}

func WithApiKeys(apiKeys []string) AppOption {
func WithApiKeys(apiKeys map[string][]string) AppOption {
return func(o *ApplicationConfig) {
o.ApiKeys = apiKeys
}
Expand Down
35 changes: 26 additions & 9 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package http
import (
"embed"
"errors"
"fmt"
"net/http"
"slices"
"strings"

"github.com/go-skynet/LocalAI/pkg/utils"
Expand Down Expand Up @@ -127,33 +129,48 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi

// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
auth := func(c *fiber.Ctx) error {
if len(appConfig.ApiKeys) == 0 {
return c.Next()
}

if len(appConfig.ApiKeys) == 0 {
return c.Next()
}

defaultCaseExists := len(appConfig.ApiKeys["_"]) > 0
fmtPath := fmt.Sprintf("%s|%s", c.Route().Method, strings.Replace(c.Route().Path, "/v1", "", -1))

authHeader := readAuthHeader(c)
if authHeader == "" {
if !defaultCaseExists && authHeader == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
}

// If it's a bearer token
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
if !defaultCaseExists {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
} else {
authHeaderParts = []string{"", ""}
}
}

apiKey := authHeaderParts[1]
for _, key := range appConfig.ApiKeys {
if apiKey == key {
return c.Next()
if apiKey != "" {
for key, endpoints := range appConfig.ApiKeys {
if apiKey == key {
log.Trace().Str("key", key).Str("fmtPath", fmtPath).Msg("found a matching api key, checking permissions for fmtPath")
if slices.Contains(endpoints, "*") || slices.Contains(endpoints, fmtPath) {
return c.Next()
}
}
}
}

return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
// Check if this is a default-allow endpoint
if defaultCaseExists && slices.Contains(appConfig.ApiKeys["_"], fmtPath) {
log.Trace().Str("fmtPath", fmtPath).Msg("matching authorization key not found, but fmtPath is on the default allow list")
return c.Next()
}

return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key", "fmtPath": fmtPath, "apiKey": apiKey})
}

if appConfig.CORS {
Expand Down
82 changes: 78 additions & 4 deletions core/startup/config_file_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"time"

"github.com/fsnotify/fsnotify"
Expand All @@ -31,14 +32,20 @@ func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler
handlers: make(map[string]fileHandler),
appConfig: appConfig,
}
err := c.Register("api_keys.json", readApiKeysJson(*appConfig), true)

err := c.Register("roles.json", readRolesJson(*appConfig), true)
if err != nil {
log.Error().Err(err).Str("file", "roles.json").Msg("unable to register config file handler")
}
err = c.Register("api_keys.json", readApiKeysJson(*appConfig), true)
if err != nil {
log.Error().Err(err).Str("file", "api_keys.json").Msg("unable to register config file handler")
}
err = c.Register("external_backends.json", readExternalBackendsJson(*appConfig), true)
if err != nil {
log.Error().Err(err).Str("file", "external_backends.json").Msg("unable to register config file handler")
}

return c
}

Expand Down Expand Up @@ -135,26 +142,93 @@ func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler {

if len(fileContent) > 0 {
// Parse JSON content from the file
var fileKeys []string
var fileKeys map[string][]string
err := json.Unmarshal(fileContent, &fileKeys)
if err != nil {
return err
}

log.Trace().Int("numKeys", len(fileKeys)).Msg("discovered API keys from api keys dynamic config dile")
appConfig.ApiKeys = startupAppConfig.ApiKeys
if appConfig.ApiKeys == nil {
appConfig.ApiKeys = map[string][]string{}
}

appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...)
log.Trace().Int("numKeys", len(fileKeys)).Msg("discovered API keys from api keys dynamic config dile")
for key, rawFileEndpoints := range fileKeys {
appConfig.ApiKeys[key] = append(startupAppConfig.ApiKeys[key], rawFileEndpoints...)
}
} else {
log.Trace().Msg("no API keys discovered from dynamic config file")
appConfig.ApiKeys = startupAppConfig.ApiKeys
}

// next, clean and process the ApiKeys for roles, duplicates, and *
// This is registered to run at startup, so will evaluate roles passed in as startupAppConfig
// quick version for now, this can be improved later
for key, endpoints := range appConfig.ApiKeys {
// Check if the starting point is enough to know the final answer
if slices.Contains(endpoints, "*") {
appConfig.ApiKeys[key] = []string{"*"}
continue
}

for { // We loop around here a second time if we make a change -- this ensures we unroll nested roles
isClean := true
for role, roleEndpoints := range appConfig.Roles {
index := slices.Index(appConfig.ApiKeys[key], role)
if index != -1 {
appConfig.ApiKeys[key] = slices.Replace(appConfig.ApiKeys[key], index, index+1, roleEndpoints...)
isClean = false
}
}
if isClean {
break
}
}
// Check if we have a "*"" yet
if slices.Contains(appConfig.ApiKeys[key], "*") {
appConfig.ApiKeys[key] = []string{"*"}
continue
}
// At this point, Sort+Compact is a simple way to deduplicate the endpoint list, no matter how the roles overlap
slices.Sort(appConfig.ApiKeys[key])
appConfig.ApiKeys[key] = slices.Compact(appConfig.ApiKeys[key])
}

log.Trace().Int("numKeys", len(appConfig.ApiKeys)).Msg("total api keys after processing")
return nil
}

return handler
}

func readRolesJson(startupAppConfig config.ApplicationConfig) fileHandler {
handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
log.Debug().Msg("processing roles runtime update")
log.Trace().Int("numRoles", len(startupAppConfig.Roles)).Msg("roles provided at startup")

if len(fileContent) > 0 {
// Parse JSON content from the file
var fileRoles map[string][]string // Roles is a simple "shortcut" mapping a name to a list of endpoints
err := json.Unmarshal(fileContent, &fileRoles)
if err != nil {
return err
}

log.Trace().Int("numRoles", len(fileRoles)).Msg("discovered roles from roles dynamic config dile")

appConfig.Roles = fileRoles
} else {
log.Trace().Msg("no roles discovered from dynamic config file")
appConfig.Roles = startupAppConfig.Roles
}
log.Trace().Int("numRoles", len(appConfig.Roles)).Msg("total roles after processing")
return nil
}

return handler
}

func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler {
handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
log.Debug().Msg("processing external_backends.json")
Expand Down