Skip to content

Commit

Permalink
integrate ability to use APIKeys
Browse files Browse the repository at this point in the history
Signed-off-by: greg pereira <[email protected]>
  • Loading branch information
Gregory-Pereira committed Sep 10, 2024
1 parent cbf176e commit 7bea03e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
7 changes: 7 additions & 0 deletions deploy/ansible/worker/tasks/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
precheck_endpoint_url != None and
precheck_endpoint_url | trim != ''

- name: Append api key arg for given precheck_endpoint
ansible.builtin.set_fact:
worker_args: "{{ worker_args }} --precheck-api-key {{ api_key }}"
when: api_key is defined and
api_key != None and
api_key | trim != ''

- name: Set tls-insecure if enabled
ansible.builtin.set_fact:
worker_args: "{{ worker_args }} --tls-insecure \"true\""
Expand Down
11 changes: 10 additions & 1 deletion worker/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
TlsClientCertPath string
TlsClientKeyPath string
TlsServerCaCertPath string
APIKey string
TlsInsecure bool
MaxSeed int
TaxonomyFolders = []string{"compositional_skills", "knowledge"}
Expand Down Expand Up @@ -79,6 +80,7 @@ type Worker struct {
logger *zap.SugaredLogger
job string
precheckEndpoint string
precheckAPIKey string
sdgEndpoint string
jobStart time.Time
tlsClientCertPath string
Expand All @@ -88,14 +90,15 @@ type Worker struct {
cmdRun string
}

func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, precheckAPIKey, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
return &Worker{
ctx: ctx,
pool: pool,
svc: svc,
logger: logger,
job: job,
precheckEndpoint: precheckEndpoint,
precheckAPIKey: precheckAPIKey,
sdgEndpoint: sdgEndpoint,
jobStart: time.Now(),
tlsClientCertPath: tlsClientCertPath,
Expand All @@ -115,6 +118,7 @@ func init() {
generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in")
generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory")
generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().StringVarP(&APIKey, "precheck-api-key", "", "", "The APIKey for the precheck-endpoint-url.")
generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate")
generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo")
Expand Down Expand Up @@ -201,6 +205,7 @@ var generateCmd = &cobra.Command{
}
NewJobProcessor(ctx, pool, svc, sugar, job,
PreCheckEndpointURL,
APIKey,
SdgEndpointURL,
TlsClientCertPath,
TlsClientKeyPath,
Expand Down Expand Up @@ -432,6 +437,10 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
if PreCheckEndpointURL != localEndpoint && modelName != "unknown" {
commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName)
}
if APIKey != "" {
commandStr += fmt.Sprintf(" --precheck-api-key %s", APIKey)
}

cmdArgs := strings.Fields(commandStr)
cmd := exec.Command(lab, cmdArgs...)
// Register the command for reporting/logging
Expand Down
2 changes: 2 additions & 0 deletions worker/cmd/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"precheck-api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down Expand Up @@ -214,6 +215,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"precheck-api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down

0 comments on commit 7bea03e

Please sign in to comment.