Skip to content

Commit

Permalink
Merge pull request #5 from cloudwego/refactor/cbs
Browse files Browse the repository at this point in the history
refactor: SetRunInfo
  • Loading branch information
shentongmartin authored Dec 17, 2024
2 parents e64a0e9 + 5964f40 commit 87f38b5
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 76 deletions.
34 changes: 3 additions & 31 deletions callbacks/aspect_inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ func OnError(ctx context.Context, err error) context.Context {
return ctx
}

// SwitchRunInfo updates the RunInfo in the context if a previous RunInfo already exists for that context.
func SwitchRunInfo(ctx context.Context, info *RunInfo) context.Context {
// SetRunInfo sets the RunInfo to be passed to Handler.
func SetRunInfo(ctx context.Context, info *RunInfo) context.Context {
cbm, ok := managerFromCtx(ctx)
if !ok {
return ctx
Expand All @@ -195,7 +195,7 @@ func SwitchRunInfo(ctx context.Context, info *RunInfo) context.Context {
}

// InitCallbacks initializes a new context with the provided RunInfo and handlers.
// If successful, it returns a new context containing RunInfo and handlers; otherwise, it returns a context with a nil manager.
// Any previously set RunInfo and Handlers for this ctx will be overwritten.
func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context {
mgr, ok := newManager(info, handlers...)
if ok {
Expand All @@ -204,31 +204,3 @@ func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) cont

return ctxWithManager(ctx, nil)
}

// Needed checks if any callback handlers exist in this context.
func Needed(ctx context.Context) bool {
_, cbmOK := managerFromCtx(ctx)
return cbmOK
}

// NeededForTiming checks if any callback handlers exist in this context that are needed for this specific timing.
func NeededForTiming(ctx context.Context, timing CallbackTiming) bool {
mgr, ok := managerFromCtx(ctx)
if !ok {
return false
}

if len(mgr.handlers) == 0 {
return false
}

for i := 0; i < len(mgr.handlers); i++ {
handler := mgr.handlers[i]
timingChecker, ok := handler.(TimingChecker)
if !ok || timingChecker.Needed(ctx, mgr.runInfo, timing) {
return true
}
}

return false
}
6 changes: 0 additions & 6 deletions callbacks/aspect_inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,5 @@ func TestAspectInject(t *testing.T) {
}
nosr.Close()
assert.Equal(t, 186, cnt)

assert.True(t, NeededForTiming(ctx, TimingOnStart))
assert.True(t, NeededForTiming(ctx, TimingOnEnd))
assert.True(t, NeededForTiming(ctx, TimingOnError))
assert.True(t, NeededForTiming(ctx, TimingOnStartWithStreamInput))
assert.True(t, NeededForTiming(ctx, TimingOnEndWithStreamOutput))
})
}
14 changes: 13 additions & 1 deletion callbacks/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ func (m *manager) withRunInfo(runInfo *RunInfo) *manager {
}
}

func (m *manager) appendHandlers(handlers ...Handler) *manager {
if m == nil {
return nil
}

return &manager{
handlers: append(m.handlers, handlers...),
runInfo: m.runInfo,
}
}

// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
func ManagerFromCtx(ctx context.Context) (*Manager, bool) {
internalM, ok := managerFromCtx(ctx)
Expand All @@ -104,7 +115,8 @@ func ManagerFromCtx(ctx context.Context) (*Manager, bool) {
}

func managerFromCtx(ctx context.Context) (*manager, bool) {
m, ok := ctx.Value(internal.CtxManagerKey{}).(*manager)
v := ctx.Value(internal.CtxManagerKey{})
m, ok := v.(*manager)
if ok && m != nil {
return &manager{
handlers: m.handlers,
Expand Down
14 changes: 7 additions & 7 deletions callbacks/template/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,19 @@ func TestNewComponentTemplate(t *testing.T) {
callbacks.OnStart(ctx, nil)
assert.Equal(t, 22, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt})
callbacks.OnStart(ctx, nil)
assert.Equal(t, 23, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer})
callbacks.OnEnd(ctx, nil)
assert.Equal(t, 23, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding})
callbacks.OnError(ctx, nil)
assert.Equal(t, 24, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
callbacks.OnStart(ctx, nil)
assert.Equal(t, 24, cnt)

Expand Down Expand Up @@ -239,11 +239,11 @@ func TestNewComponentTemplate(t *testing.T) {
callbacks.OnEnd(ctx, nil)
assert.Equal(t, 25, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer})
callbacks.OnStart(ctx, nil)
assert.Equal(t, 26, cnt)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader})
callbacks.OnEnd(ctx, nil)
assert.Equal(t, 27, cnt)
})
Expand Down Expand Up @@ -328,7 +328,7 @@ func TestNewComponentTemplate(t *testing.T) {
callbacks.OnEndWithStreamOutput(ctx, &schema.StreamReader[callbacks.CallbackOutput]{})
assert.Equal(t, 10, cntf)

ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever})
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever})
callbacks.OnStart(ctx, nil)
assert.Equal(t, 1, cnt)
})
Expand Down
4 changes: 2 additions & 2 deletions compose/tool_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, er
}

func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...tool.Option) {
ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{
Name: task.name,
Type: task.meta.componentImplType,
Component: task.meta.component,
Expand All @@ -176,7 +176,7 @@ func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...to
}

func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...tool.Option) {
ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{
ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{
Name: task.name,
Type: task.meta.componentImplType,
Component: task.meta.component,
Expand Down
25 changes: 0 additions & 25 deletions compose/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ func mergeValues(vs []any) (any, error) {

func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O, TOption] {
return func(ctx context.Context, input I, opts ...TOption) (output O, err error) {
if !callbacks.Needed(ctx) {
return i(ctx, input, opts...)
}

defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
Expand All @@ -111,10 +107,6 @@ func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O

func genericInvokeWithCallbacks(i invoke) invoke {
return func(ctx context.Context, input any, opts ...any) (output any, err error) {
if !callbacks.Needed(ctx) {
return i(ctx, input, opts...)
}

defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
Expand All @@ -133,10 +125,6 @@ func genericInvokeWithCallbacks(i invoke) invoke {

func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O, TOption] {
return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) {
if !callbacks.Needed(ctx) {
return s(ctx, input, opts...)
}

ctx = callbacks.OnStart(ctx, input)

output, err = s(ctx, input, opts...)
Expand All @@ -153,10 +141,6 @@ func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O

func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I, O, TOption] {
return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) {
if !callbacks.Needed(ctx) {
return c(ctx, input, opts...)
}

defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
Expand All @@ -174,11 +158,6 @@ func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I
func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Transform[I, O, TOption] {
return func(ctx context.Context, input *schema.StreamReader[I],
opts ...TOption) (output *schema.StreamReader[O], err error) {

if !callbacks.Needed(ctx) {
return t(ctx, input, opts...)
}

ctx, input = callbacks.OnStartWithStreamInput(ctx, input)

output, err = t(ctx, input, opts...)
Expand All @@ -195,10 +174,6 @@ func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Trans

func genericTransformWithCallbacks(t transform) transform {
return func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) {
if !callbacks.Needed(ctx) {
return t(ctx, input, opts...)
}

inArr := input.copy(2)
is, ok := unpackStreamReader[callbacks.CallbackInput](inArr[1])
if !ok { // unexpected
Expand Down
2 changes: 1 addition & 1 deletion flow/retriever/multiquery/multi_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,5 @@ func ctxWithFusionRunInfo(ctx context.Context) context.Context {

runInfo.Name = runInfo.Type + string(runInfo.Component)

return callbacks.SwitchRunInfo(ctx, runInfo)
return callbacks.SetRunInfo(ctx, runInfo)
}
4 changes: 2 additions & 2 deletions flow/retriever/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func ctxWithRouterRunInfo(ctx context.Context) context.Context {

runInfo.Name = runInfo.Type + string(runInfo.Component)

return callbacks.SwitchRunInfo(ctx, runInfo)
return callbacks.SetRunInfo(ctx, runInfo)
}

func ctxWithFusionRunInfo(ctx context.Context) context.Context {
Expand All @@ -189,5 +189,5 @@ func ctxWithFusionRunInfo(ctx context.Context) context.Context {

runInfo.Name = runInfo.Type + string(runInfo.Component)

return callbacks.SwitchRunInfo(ctx, runInfo)
return callbacks.SetRunInfo(ctx, runInfo)
}
2 changes: 1 addition & 1 deletion flow/retriever/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,5 @@ func ctxWithRetrieverRunInfo(ctx context.Context, r retriever.Retriever) context

runInfo.Name = runInfo.Type + string(runInfo.Component)

return callbacks.SwitchRunInfo(ctx, runInfo)
return callbacks.SetRunInfo(ctx, runInfo)
}

0 comments on commit 87f38b5

Please sign in to comment.