Skip to content

Commit

Permalink
integrate ability to use APIKeys
Browse files Browse the repository at this point in the history
Signed-off-by: greg pereira <[email protected]>
  • Loading branch information
Gregory-Pereira committed Sep 10, 2024
1 parent cbf176e commit 9cc7b63
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ui/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ const (
redisQueueArchive = "archived"
)

const PreCheckEndpointURL = "https://merlinite-7b-vllm-openai.apps.fmaas-backend.fmaas.res.ibm.com/v1"
const PreCheckEndpointURL = "http://cmb-staging.asgharlabs.io:8505/v1"
const InstructLabBotUrl = "http://bot:8081"

type ApiServer struct {
apiKey string
router *gin.Engine
logger *zap.SugaredLogger
redis *redis.Client
Expand Down Expand Up @@ -342,6 +343,9 @@ func (api *ApiServer) runIlabChatCommand(question, context string) (string, erro
return "failed to retrieve the model name", err
}
commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", api.preCheckEndpointURL, modelName)
if api.apiKey != "" {
commandStr += fmt.Sprintf(" --api-key %s", api.apiKey)
}
cmdArgs := strings.Fields(commandStr)
cmd = exec.Command("ilab", cmdArgs...)
api.logger.Infof("Running in production mode with model name %s: %s", modelName, commandStr)
Expand Down Expand Up @@ -457,6 +461,7 @@ func main() {
testMode := pflag.Bool("test-mode", false, "Don't run ilab commands, just echo back the ilab command to the chat response")
listenAddress := pflag.String("listen-address", "localhost:3000", "Address to listen on")
redisAddress := pflag.String("redis-server", "localhost:6379", "Redis server address")
apiKey := pflag.String("api-key", "", "API key for the given `precheck-endpoint`")
apiUser := pflag.String("api-user", "", "API username")
apiPass := pflag.String("api-pass", "", "API password")
preCheckEndpointURL := pflag.String("precheck-endpoint", PreCheckEndpointURL, "Precheck endpoint URL")
Expand All @@ -476,6 +481,7 @@ func main() {

router := gin.Default()
svr := ApiServer{
apiKey: *apiKey,
router: router,
logger: logger,
redis: rdb,
Expand Down
11 changes: 10 additions & 1 deletion worker/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
TlsClientCertPath string
TlsClientKeyPath string
TlsServerCaCertPath string
APIKey string
TlsInsecure bool
MaxSeed int
TaxonomyFolders = []string{"compositional_skills", "knowledge"}
Expand Down Expand Up @@ -79,6 +80,7 @@ type Worker struct {
logger *zap.SugaredLogger
job string
precheckEndpoint string
apiKey string
sdgEndpoint string
jobStart time.Time
tlsClientCertPath string
Expand All @@ -88,14 +90,15 @@ type Worker struct {
cmdRun string
}

func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, apiKey, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
return &Worker{
ctx: ctx,
pool: pool,
svc: svc,
logger: logger,
job: job,
precheckEndpoint: precheckEndpoint,
apiKey: apiKey,
sdgEndpoint: sdgEndpoint,
jobStart: time.Now(),
tlsClientCertPath: tlsClientCertPath,
Expand All @@ -115,6 +118,7 @@ func init() {
generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in")
generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory")
generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().StringVarP(&APIKey, "api-key", "", "", "The APIKey for the precheck-endpoint-url.")
generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate")
generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo")
Expand Down Expand Up @@ -201,6 +205,7 @@ var generateCmd = &cobra.Command{
}
NewJobProcessor(ctx, pool, svc, sugar, job,
PreCheckEndpointURL,
APIKey,
SdgEndpointURL,
TlsClientCertPath,
TlsClientKeyPath,
Expand Down Expand Up @@ -432,6 +437,10 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
if PreCheckEndpointURL != localEndpoint && modelName != "unknown" {
commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName)
}
if APIKey != "" {
commandStr += fmt.Sprintf(" --api-key %s", APIKey)
}

cmdArgs := strings.Fields(commandStr)
cmd := exec.Command(lab, cmdArgs...)
// Register the command for reporting/logging
Expand Down
2 changes: 2 additions & 0 deletions worker/cmd/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down Expand Up @@ -214,6 +215,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down

0 comments on commit 9cc7b63

Please sign in to comment.