diff --git a/ui/apiserver/apiserver.go b/ui/apiserver/apiserver.go index aa78196..9c0fec0 100644 --- a/ui/apiserver/apiserver.go +++ b/ui/apiserver/apiserver.go @@ -25,10 +25,11 @@ const ( redisQueueArchive = "archived" ) -const PreCheckEndpointURL = "https://merlinite-7b-vllm-openai.apps.fmaas-backend.fmaas.res.ibm.com/v1" +const PreCheckEndpointURL = "http://cmb-staging.asgharlabs.io:8505/v1" const InstructLabBotUrl = "http://bot:8081" type ApiServer struct { + apiKey string router *gin.Engine logger *zap.SugaredLogger redis *redis.Client @@ -342,6 +343,9 @@ func (api *ApiServer) runIlabChatCommand(question, context string) (string, erro return "failed to retrieve the model name", err } commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", api.preCheckEndpointURL, modelName) + if api.apiKey != "" { + commandStr += fmt.Sprintf(" --api-key %s", api.apiKey) + } cmdArgs := strings.Fields(commandStr) cmd = exec.Command("ilab", cmdArgs...) api.logger.Infof("Running in production mode with model name %s: %s", modelName, commandStr) @@ -457,6 +461,7 @@ func main() { testMode := pflag.Bool("test-mode", false, "Don't run ilab commands, just echo back the ilab command to the chat response") listenAddress := pflag.String("listen-address", "localhost:3000", "Address to listen on") redisAddress := pflag.String("redis-server", "localhost:6379", "Redis server address") + apiKey := pflag.String("api-key", "", "API key for the given `precheck-endpoint`") apiUser := pflag.String("api-user", "", "API username") apiPass := pflag.String("api-pass", "", "API password") preCheckEndpointURL := pflag.String("precheck-endpoint", PreCheckEndpointURL, "Precheck endpoint URL") @@ -476,6 +481,7 @@ func main() { router := gin.Default() svr := ApiServer{ + apiKey: *apiKey, router: router, logger: logger, redis: rdb, diff --git a/worker/cmd/generate.go b/worker/cmd/generate.go index 4fcc650..45f08a2 100644 --- a/worker/cmd/generate.go +++ b/worker/cmd/generate.go @@ -46,6 +46,7 @@ var ( TlsClientCertPath string TlsClientKeyPath string TlsServerCaCertPath string + APIKey string TlsInsecure bool MaxSeed int TaxonomyFolders = []string{"compositional_skills", "knowledge"} @@ -79,6 +80,7 @@ type Worker struct { logger *zap.SugaredLogger job string precheckEndpoint string + apiKey string sdgEndpoint string jobStart time.Time tlsClientCertPath string @@ -88,7 +90,7 @@ type Worker struct { cmdRun string } -func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { +func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, apiKey, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { return &Worker{ ctx: ctx, pool: pool, @@ -96,6 +98,7 @@ func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logg logger: logger, job: job, precheckEndpoint: precheckEndpoint, + apiKey: apiKey, sdgEndpoint: sdgEndpoint, jobStart: time.Now(), tlsClientCertPath: tlsClientCertPath, @@ -115,6 +118,7 @@ func init() { generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in") generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory") generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "", "Endpoint hosting the model API. Default, it assumes the model is served locally.") + generateCmd.Flags().StringVarP(&APIKey, "api-key", "", "", "The APIKey for the precheck-endpoint-url.") generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.") generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate") generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo") @@ -201,6 +205,7 @@ var generateCmd = &cobra.Command{ } NewJobProcessor(ctx, pool, svc, sugar, job, PreCheckEndpointURL, + APIKey, SdgEndpointURL, TlsClientCertPath, TlsClientKeyPath, @@ -432,6 +437,10 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { if PreCheckEndpointURL != localEndpoint && modelName != "unknown" { commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName) } + if APIKey != "" { + commandStr += fmt.Sprintf(" --api-key %s", APIKey) + } + cmdArgs := strings.Fields(commandStr) cmd := exec.Command(lab, cmdArgs...) // Register the command for reporting/logging diff --git a/worker/cmd/generate_test.go b/worker/cmd/generate_test.go index 6102c18..46118d9 100644 --- a/worker/cmd/generate_test.go +++ b/worker/cmd/generate_test.go @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + "api-key", "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem", @@ -214,6 +215,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + "api-key", "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem",