Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Feb 20, 2024
1 parent a5f8368 commit 779b336
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 32 deletions.
25 changes: 11 additions & 14 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,42 +113,39 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...)
for i := 0; i < len(agentDeployments); i++ {
client := cs.agentMetadataClients[agentDeployments[i].Endpoint]
for _, agentDeployment := range agentDeployments {
client := cs.agentMetadataClients[agentDeployment.Endpoint]

finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployments[i])
finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment)
defer cancel()

res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpcStatus, ok := status.FromError(err)
if grpcStatus.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployments[i])
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
continue
}

if !ok {
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployments[i], err)
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err)
}

return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployments[i], err)
return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err)
}

agents := res.GetAgents()
for j := 0; j < len(agents); j++ {
supportedTaskTypes := agents[j].SupportedTaskTypes
for _, agent := range res.GetAgents() {
supportedTaskTypes := agent.SupportedTaskTypes
for _, supportedTaskType := range supportedTaskTypes {
agent := &Agent{AgentDeployment: agentDeployments[i], IsSync: agents[j].IsSync}
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedTaskType.GetName()] = map[int32]*Agent{supportedTaskType.GetVersion(): agent}
}
logger.Infof(context.Background(), "[%v] AgentDeployment is a sync agent: [%v]", agents[j].Name, agents[j].IsSync)
logger.Infof(context.Background(), "[%v] AgentDeployment supports task types: [%v]", agents[j].Name, supportedTaskTypes)
logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync)
logger.Infof(context.Background(), "[%v] supports task types: [%v]", agent.Name, supportedTaskTypes)
}
}

logger.Infof(context.Background(), "AgentDeployment registry initialized: [%v]", agentRegistry["mock_openai"][0])

return agentRegistry, nil
}

Expand Down
1 change: 0 additions & 1 deletion flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func TestInitializeClients(t *testing.T) {
},
"y": {
Endpoint: "y",
IsSync: true,
},
}
ctx := context.Background()
Expand Down
21 changes: 4 additions & 17 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) {
},
},
}
expectedOutputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"x": []interface{}{1, 2}})
expectedOutputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
assert.NoError(t, err)
phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, expectedOutputs, nil, iter)
assert.Equal(t, true, phase.Phase().IsSuccess())
Expand Down Expand Up @@ -290,9 +290,8 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry {

func newMockSyncAgentPlugin() webapi.PluginEntry {
syncAgentClient := new(agentMocks.SyncAgentServiceClient)
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED}
output1, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
output2, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 2})
output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output}

stream := new(agentMocks.SyncAgentService_ExecuteTaskSyncClient)
stream.OnRecv().Return(&admin.ExecuteTaskSyncResponse{
Expand All @@ -303,18 +302,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
},
}, nil).Once()

stream.OnRecv().Return(&admin.ExecuteTaskSyncResponse{
Res: &admin.ExecuteTaskSyncResponse_Outputs{
Outputs: output1,
},
}, nil).Once()

stream.OnRecv().Return(&admin.ExecuteTaskSyncResponse{
Res: &admin.ExecuteTaskSyncResponse_Outputs{
Outputs: output2,
},
}, nil).Once()

stream.OnRecv().Return(nil, io.EOF).Once()
stream.OnSendMatch(mock.Anything).Return(nil)
stream.OnCloseSendMatch(mock.Anything).Return(nil)
Expand All @@ -323,7 +310,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {

cfg := defaultConfig
cfg.DefaultAgent.Endpoint = defaultAgentEndpoint
cfg.DefaultAgent.IsSync = true

return webapi.PluginEntry{
ID: "agent-service",
Expand All @@ -337,6 +323,7 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: syncAgentClient,
},
},
agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &AgentDeployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}},
}, nil
},
}
Expand Down

0 comments on commit 779b336

Please sign in to comment.