Skip to content

Commit

Permalink
feat(databricks): support override Databricks instances
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed May 13, 2024
1 parent 2f38d65 commit 9c66353
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ func TestEndToEnd(t *testing.T) {
}
databricksConfig, err := utils.MarshalObjToStruct(databricksConfDict)
assert.NoError(t, err)
sparkJob := plugins.SparkJob{DatabricksConf: databricksConfig, DatabricksToken: "token", SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"}}
sparkJob := plugins.SparkJob{
DatabricksConf: databricksConfig,
DatabricksToken: "token",
DatabricksInstance: "Foo",
SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"},
}
st, err := utils.MarshalPbToStruct(&sparkJob)
assert.NoError(t, err)
inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
Expand Down
17 changes: 10 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
}
}
databricksJob[sparkPythonTask] = map[string]interface{}{pythonFile: p.cfg.EntrypointFile, parameters: modifiedArgs}

data, err := p.sendRequest(create, databricksJob, token, "")
databricksInstance := p.cfg.DatabricksInstance
if sparkJob.DatabricksInstance != "" {
databricksInstance = sparkJob.DatabricksInstance
}
data, err := p.sendRequest(create, databricksJob, token, "", databricksInstance)
if err != nil {
return nil, nil, err
}
Expand All @@ -138,12 +141,12 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
}
runID := fmt.Sprintf("%.0f", data["run_id"])

return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil
return ResourceMetaWrapper{runID, databricksInstance, token}, nil, nil
}

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
res, err := p.sendRequest(get, nil, exec.Token, exec.RunID)
res, err := p.sendRequest(get, nil, exec.Token, exec.RunID, exec.DatabricksInstance)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -175,7 +178,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return nil
}
exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
_, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID)
_, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID, exec.DatabricksInstance)

Check warning on line 181 in flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go#L181

Added line #L181 was not covered by tests
if err != nil {
return err
}
Expand All @@ -184,11 +187,11 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return nil
}

func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token string, runID string) (map[string]interface{}, error) {
func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token, runID, databricksInstance string) (map[string]interface{}, error) {
var databricksURL string
// for mocking/testing purposes
if p.cfg.databricksEndpoint == "" {
databricksURL = fmt.Sprintf("https://%v%v", p.cfg.DatabricksInstance, databricksAPI)
databricksURL = fmt.Sprintf("https://%v%v", databricksInstance, databricksAPI)
} else {
databricksURL = fmt.Sprintf("%v%v", p.cfg.databricksEndpoint, databricksAPI)
}
Expand Down
12 changes: 6 additions & 6 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestSendRequest(t *testing.T) {
}

t.Run("create a Databricks job", func(t *testing.T) {
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Equal(t, "someID", data["id"])
assert.Equal(t, "someData", data["data"])
Expand All @@ -88,7 +88,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"failed"}`)),
}, nil
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to create Databricks job with error [failed]")
})
Expand All @@ -98,7 +98,7 @@ func TestSendRequest(t *testing.T) {
assert.Equal(t, req.Method, http.MethodPost)
return nil, errors.New("failed to send request")
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to send request to Databricks platform with err: [failed to send request]")
})
Expand All @@ -111,7 +111,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`123`)),
}, nil
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to parse response with err: [json: cannot unmarshal number into Go value of type map[string]interface {}]")
})
Expand All @@ -124,7 +124,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
}, nil
}}
data, err := plugin.sendRequest(get, databricksJob, token, "")
data, err := plugin.sendRequest(get, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Nil(t, err)
})
Expand All @@ -137,7 +137,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
}, nil
}}
data, err := plugin.sendRequest(cancel, databricksJob, token, "")
data, err := plugin.sendRequest(cancel, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Nil(t, err)
})
Expand Down

0 comments on commit 9c66353

Please sign in to comment.