diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index 86ca034e7b..ce8e16f1fc 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -2,6 +2,7 @@ package dask import ( "context" + "reflect" "testing" "time" @@ -197,6 +198,19 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc overrides.OnGetContainerImage().Return("") taskExecutionMetadata.OnGetOverrides().Return(overrides) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -699,21 +713,21 @@ func TestGetTaskPhaseDask(t *testing.T) { assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobClusterCreated)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobRunning)) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index aea096e4ba..68f8f5b65d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -3,6 +3,7 @@ package pytorch import ( "context" "fmt" + "reflect" "testing" "time" @@ -176,6 +177,19 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1. taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 80e95871d1..f6c0afb068 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -3,6 +3,7 @@ package tensorflow import ( "context" "fmt" + "reflect" "testing" "time" @@ -171,6 +172,19 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 69f65c19ef..42ab0bf5ac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -2,6 +2,7 @@ package ray import ( "context" + "reflect" "testing" "time" @@ -707,6 +708,20 @@ func newPluginContext() k8s.PluginContext { tskCtx := &mocks.TaskExecutionMetadata{} tskCtx.OnGetTaskExecutionID().Return(taskExecID) plg.OnTaskExecutionMetadata().Return(tskCtx) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + plg.OnPluginStateReader().Return(&pluginStateReaderMock) + return plg }