From 5964f402d9a9839056be31478447ae68388621d8 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 17 Dec 2024 18:07:13 +0800 Subject: [PATCH] refactor: SetRunInfo Change-Id: Ic602e6d3af3dec90d4ff26bc194e764a38617c7a --- callbacks/aspect_inject.go | 34 +++--------------------- callbacks/aspect_inject_test.go | 6 ----- callbacks/manager.go | 14 +++++++++- callbacks/template/template_test.go | 14 +++++----- compose/tool_node.go | 4 +-- compose/utils.go | 25 ----------------- flow/retriever/multiquery/multi_query.go | 2 +- flow/retriever/router/router.go | 4 +-- flow/retriever/utils/utils.go | 2 +- 9 files changed, 29 insertions(+), 76 deletions(-) diff --git a/callbacks/aspect_inject.go b/callbacks/aspect_inject.go index 2ec6836..7e99098 100644 --- a/callbacks/aspect_inject.go +++ b/callbacks/aspect_inject.go @@ -184,8 +184,8 @@ func OnError(ctx context.Context, err error) context.Context { return ctx } -// SwitchRunInfo updates the RunInfo in the context if a previous RunInfo already exists for that context. -func SwitchRunInfo(ctx context.Context, info *RunInfo) context.Context { +// SetRunInfo sets the RunInfo to be passed to Handler. +func SetRunInfo(ctx context.Context, info *RunInfo) context.Context { cbm, ok := managerFromCtx(ctx) if !ok { return ctx @@ -195,7 +195,7 @@ func SwitchRunInfo(ctx context.Context, info *RunInfo) context.Context { } // InitCallbacks initializes a new context with the provided RunInfo and handlers. -// If successful, it returns a new context containing RunInfo and handlers; otherwise, it returns a context with a nil manager. +// Any previously set RunInfo and Handlers for this ctx will be overwritten. func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context { mgr, ok := newManager(info, handlers...) if ok { @@ -204,31 +204,3 @@ func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) cont return ctxWithManager(ctx, nil) } - -// Needed checks if any callback handlers exist in this context. -func Needed(ctx context.Context) bool { - _, cbmOK := managerFromCtx(ctx) - return cbmOK -} - -// NeededForTiming checks if any callback handlers exist in this context that are needed for this specific timing. -func NeededForTiming(ctx context.Context, timing CallbackTiming) bool { - mgr, ok := managerFromCtx(ctx) - if !ok { - return false - } - - if len(mgr.handlers) == 0 { - return false - } - - for i := 0; i < len(mgr.handlers); i++ { - handler := mgr.handlers[i] - timingChecker, ok := handler.(TimingChecker) - if !ok || timingChecker.Needed(ctx, mgr.runInfo, timing) { - return true - } - } - - return false -} diff --git a/callbacks/aspect_inject_test.go b/callbacks/aspect_inject_test.go index 0a3af79..e0f5d80 100644 --- a/callbacks/aspect_inject_test.go +++ b/callbacks/aspect_inject_test.go @@ -181,11 +181,5 @@ func TestAspectInject(t *testing.T) { } nosr.Close() assert.Equal(t, 186, cnt) - - assert.True(t, NeededForTiming(ctx, TimingOnStart)) - assert.True(t, NeededForTiming(ctx, TimingOnEnd)) - assert.True(t, NeededForTiming(ctx, TimingOnError)) - assert.True(t, NeededForTiming(ctx, TimingOnStartWithStreamInput)) - assert.True(t, NeededForTiming(ctx, TimingOnEndWithStreamOutput)) }) } diff --git a/callbacks/manager.go b/callbacks/manager.go index 90237c1..be42cb1 100644 --- a/callbacks/manager.go +++ b/callbacks/manager.go @@ -91,6 +91,17 @@ func (m *manager) withRunInfo(runInfo *RunInfo) *manager { } } +func (m *manager) appendHandlers(handlers ...Handler) *manager { + if m == nil { + return nil + } + + return &manager{ + handlers: append(m.handlers, handlers...), + runInfo: m.runInfo, + } +} + // Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead func ManagerFromCtx(ctx context.Context) (*Manager, bool) { internalM, ok := managerFromCtx(ctx) @@ -104,7 +115,8 @@ func ManagerFromCtx(ctx context.Context) (*Manager, bool) { } func managerFromCtx(ctx context.Context) (*manager, bool) { - m, ok := ctx.Value(internal.CtxManagerKey{}).(*manager) + v := ctx.Value(internal.CtxManagerKey{}) + m, ok := v.(*manager) if ok && m != nil { return &manager{ handlers: m.handlers, diff --git a/callbacks/template/template_test.go b/callbacks/template/template_test.go index c9f03b5..a6c86e4 100644 --- a/callbacks/template/template_test.go +++ b/callbacks/template/template_test.go @@ -176,19 +176,19 @@ func TestNewComponentTemplate(t *testing.T) { callbacks.OnStart(ctx, nil) assert.Equal(t, 22, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) callbacks.OnStart(ctx, nil) assert.Equal(t, 23, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd(ctx, nil) assert.Equal(t, 23, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) assert.Equal(t, 24, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart(ctx, nil) assert.Equal(t, 24, cnt) @@ -239,11 +239,11 @@ func TestNewComponentTemplate(t *testing.T) { callbacks.OnEnd(ctx, nil) assert.Equal(t, 25, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnStart(ctx, nil) assert.Equal(t, 26, cnt) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd(ctx, nil) assert.Equal(t, 27, cnt) }) @@ -328,7 +328,7 @@ func TestNewComponentTemplate(t *testing.T) { callbacks.OnEndWithStreamOutput(ctx, &schema.StreamReader[callbacks.CallbackOutput]{}) assert.Equal(t, 10, cntf) - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}) + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}) callbacks.OnStart(ctx, nil) assert.Equal(t, 1, cnt) }) diff --git a/compose/tool_node.go b/compose/tool_node.go index 03685ff..1e38a11 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -167,7 +167,7 @@ func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, er } func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...tool.Option) { - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{ + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{ Name: task.name, Type: task.meta.componentImplType, Component: task.meta.component, @@ -176,7 +176,7 @@ func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...to } func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...tool.Option) { - ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{ + ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{ Name: task.name, Type: task.meta.componentImplType, Component: task.meta.component, diff --git a/compose/utils.go b/compose/utils.go index 4907913..9a0ef9f 100644 --- a/compose/utils.go +++ b/compose/utils.go @@ -89,10 +89,6 @@ func mergeValues(vs []any) (any, error) { func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { - if !callbacks.Needed(ctx) { - return i(ctx, input, opts...) - } - defer func() { if err != nil { _ = callbacks.OnError(ctx, err) @@ -111,10 +107,6 @@ func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O func genericInvokeWithCallbacks(i invoke) invoke { return func(ctx context.Context, input any, opts ...any) (output any, err error) { - if !callbacks.Needed(ctx) { - return i(ctx, input, opts...) - } - defer func() { if err != nil { _ = callbacks.OnError(ctx, err) @@ -133,10 +125,6 @@ func genericInvokeWithCallbacks(i invoke) invoke { func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O, TOption] { return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { - if !callbacks.Needed(ctx) { - return s(ctx, input, opts...) - } - ctx = callbacks.OnStart(ctx, input) output, err = s(ctx, input, opts...) @@ -153,10 +141,6 @@ func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { - if !callbacks.Needed(ctx) { - return c(ctx, input, opts...) - } - defer func() { if err != nil { _ = callbacks.OnError(ctx, err) @@ -174,11 +158,6 @@ func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Transform[I, O, TOption] { return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { - - if !callbacks.Needed(ctx) { - return t(ctx, input, opts...) - } - ctx, input = callbacks.OnStartWithStreamInput(ctx, input) output, err = t(ctx, input, opts...) @@ -195,10 +174,6 @@ func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Trans func genericTransformWithCallbacks(t transform) transform { return func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { - if !callbacks.Needed(ctx) { - return t(ctx, input, opts...) - } - inArr := input.copy(2) is, ok := unpackStreamReader[callbacks.CallbackInput](inArr[1]) if !ok { // unexpected diff --git a/flow/retriever/multiquery/multi_query.go b/flow/retriever/multiquery/multi_query.go index b51b40c..dd5ef83 100644 --- a/flow/retriever/multiquery/multi_query.go +++ b/flow/retriever/multiquery/multi_query.go @@ -207,5 +207,5 @@ func ctxWithFusionRunInfo(ctx context.Context) context.Context { runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.SetRunInfo(ctx, runInfo) } diff --git a/flow/retriever/router/router.go b/flow/retriever/router/router.go index 3cb131a..0fe8f56 100644 --- a/flow/retriever/router/router.go +++ b/flow/retriever/router/router.go @@ -178,7 +178,7 @@ func ctxWithRouterRunInfo(ctx context.Context) context.Context { runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.SetRunInfo(ctx, runInfo) } func ctxWithFusionRunInfo(ctx context.Context) context.Context { @@ -189,5 +189,5 @@ func ctxWithFusionRunInfo(ctx context.Context) context.Context { runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.SetRunInfo(ctx, runInfo) } diff --git a/flow/retriever/utils/utils.go b/flow/retriever/utils/utils.go index 482a38e..ce57234 100644 --- a/flow/retriever/utils/utils.go +++ b/flow/retriever/utils/utils.go @@ -79,5 +79,5 @@ func ctxWithRetrieverRunInfo(ctx context.Context, r retriever.Retriever) context runInfo.Name = runInfo.Type + string(runInfo.Component) - return callbacks.SwitchRunInfo(ctx, runInfo) + return callbacks.SetRunInfo(ctx, runInfo) }