Skip to content

Commit

Permalink
Fix dask, pytorch, tensorflow, and mpi tests
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 committed Apr 8, 2024
1 parent 3fd60ed commit c72b69c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
20 changes: 17 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dask

import (
"context"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -197,6 +198,19 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
overrides.OnGetContainerImage().Return("")
taskExecutionMetadata.OnGetOverrides().Return(overrides)
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
return 0
},
func(v interface{}) error {
return nil
})

taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock)
return taskCtx
}

Expand Down Expand Up @@ -699,21 +713,21 @@ func TestGetTaskPhaseDask(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing)
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, taskPhase.Info().Logs)
assert.NotNil(t, taskPhase.Info().Logs)
assert.Nil(t, err)

taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated))
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing)
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, taskPhase.Info().Logs)
assert.NotNil(t, taskPhase.Info().Logs)
assert.Nil(t, err)

taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobClusterCreated))
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing)
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, taskPhase.Info().Logs)
assert.NotNil(t, taskPhase.Info().Logs)
assert.Nil(t, err)

taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobRunning))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pytorch
import (
"context"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -176,6 +177,19 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
return 0
},
func(v interface{}) error {
return nil
})

taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock)
return taskCtx
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tensorflow
import (
"context"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -171,6 +172,19 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
return 0
},
func(v interface{}) error {
return nil
})

taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock)
return taskCtx
}

Expand Down
15 changes: 15 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ray

import (
"context"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -707,6 +708,20 @@ func newPluginContext() k8s.PluginContext {
tskCtx := &mocks.TaskExecutionMetadata{}
tskCtx.OnGetTaskExecutionID().Return(taskExecID)
plg.OnTaskExecutionMetadata().Return(tskCtx)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
return 0
},
func(v interface{}) error {
return nil
})

plg.OnPluginStateReader().Return(&pluginStateReaderMock)

return plg
}

Expand Down

0 comments on commit c72b69c

Please sign in to comment.