From 9c663535c081138e72977608f0772d5a0318f031 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 13 May 2024 23:25:44 +0800 Subject: [PATCH] feat(databricks): support override Databricks instances Signed-off-by: Kevin Su --- .../webapi/databricks/integration_test.go | 7 ++++++- .../tasks/plugins/webapi/databricks/plugin.go | 17 ++++++++++------- .../plugins/webapi/databricks/plugin_test.go | 12 ++++++------ 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go index d18f4ba79e..20ef96fb9c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go @@ -57,7 +57,12 @@ func TestEndToEnd(t *testing.T) { } databricksConfig, err := utils.MarshalObjToStruct(databricksConfDict) assert.NoError(t, err) - sparkJob := plugins.SparkJob{DatabricksConf: databricksConfig, DatabricksToken: "token", SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"}} + sparkJob := plugins.SparkJob{ + DatabricksConf: databricksConfig, + DatabricksToken: "token", + DatabricksInstance: "Foo", + SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"}, + } st, err := utils.MarshalPbToStruct(&sparkJob) assert.NoError(t, err) inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go index 6ae9a1dbe5..1793d8090a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go @@ -127,8 +127,11 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR } } databricksJob[sparkPythonTask] = map[string]interface{}{pythonFile: p.cfg.EntrypointFile, parameters: modifiedArgs} - - data, err := p.sendRequest(create, databricksJob, token, "") + databricksInstance := p.cfg.DatabricksInstance + if sparkJob.DatabricksInstance != "" { + databricksInstance = sparkJob.DatabricksInstance + } + data, err := p.sendRequest(create, databricksJob, token, "", databricksInstance) if err != nil { return nil, nil, err } @@ -138,12 +141,12 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR } runID := fmt.Sprintf("%.0f", data["run_id"]) - return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil + return ResourceMetaWrapper{runID, databricksInstance, token}, nil, nil } func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) - res, err := p.sendRequest(get, nil, exec.Token, exec.RunID) + res, err := p.sendRequest(get, nil, exec.Token, exec.RunID, exec.DatabricksInstance) if err != nil { return nil, err } @@ -175,7 +178,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) - _, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID) + _, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID, exec.DatabricksInstance) if err != nil { return err } @@ -184,11 +187,11 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } -func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token string, runID string) (map[string]interface{}, error) { +func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token, runID, databricksInstance string) (map[string]interface{}, error) { var databricksURL string // for mocking/testing purposes if p.cfg.databricksEndpoint == "" { - databricksURL = fmt.Sprintf("https://%v%v", p.cfg.DatabricksInstance, databricksAPI) + databricksURL = fmt.Sprintf("https://%v%v", databricksInstance, databricksAPI) } else { databricksURL = fmt.Sprintf("%v%v", p.cfg.databricksEndpoint, databricksAPI) } diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go index 228914af93..fdbd51e2c7 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go @@ -73,7 +73,7 @@ func TestSendRequest(t *testing.T) { } t.Run("create a Databricks job", func(t *testing.T) { - data, err := plugin.sendRequest(create, databricksJob, token, "") + data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.NotNil(t, data) assert.Equal(t, "someID", data["id"]) assert.Equal(t, "someData", data["data"]) @@ -88,7 +88,7 @@ func TestSendRequest(t *testing.T) { Body: ioutils.NewBytesReadCloser([]byte(`{"message":"failed"}`)), }, nil }} - data, err := plugin.sendRequest(create, databricksJob, token, "") + data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.Nil(t, data) assert.Equal(t, err.Error(), "failed to create Databricks job with error [failed]") }) @@ -98,7 +98,7 @@ func TestSendRequest(t *testing.T) { assert.Equal(t, req.Method, http.MethodPost) return nil, errors.New("failed to send request") }} - data, err := plugin.sendRequest(create, databricksJob, token, "") + data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.Nil(t, data) assert.Equal(t, err.Error(), "failed to send request to Databricks platform with err: [failed to send request]") }) @@ -111,7 +111,7 @@ func TestSendRequest(t *testing.T) { Body: ioutils.NewBytesReadCloser([]byte(`123`)), }, nil }} - data, err := plugin.sendRequest(create, databricksJob, token, "") + data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.Nil(t, data) assert.Equal(t, err.Error(), "failed to parse response with err: [json: cannot unmarshal number into Go value of type map[string]interface {}]") }) @@ -124,7 +124,7 @@ func TestSendRequest(t *testing.T) { Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)), }, nil }} - data, err := plugin.sendRequest(get, databricksJob, token, "") + data, err := plugin.sendRequest(get, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.NotNil(t, data) assert.Nil(t, err) }) @@ -137,7 +137,7 @@ func TestSendRequest(t *testing.T) { Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)), }, nil }} - data, err := plugin.sendRequest(cancel, databricksJob, token, "") + data, err := plugin.sendRequest(cancel, databricksJob, token, "", plugin.cfg.DatabricksInstance) assert.NotNil(t, data) assert.Nil(t, err) })