From 5a3d3583b5583af9143aba6d5edd68bd01bac0e0 Mon Sep 17 00:00:00 2001 From: "shentong.martin" Date: Tue, 17 Dec 2024 19:43:55 +0800 Subject: [PATCH] refactor: move callback template to utils Change-Id: I0296020a9f49565084281d71f79d47df25da5fef --- callbacks/handler_builder.go | 2 +- callbacks/interface.go | 2 +- callbacks/template/default.go | 51 --- components/document/callback_extra_loader.go | 23 -- .../document/callback_extra_transformer.go | 23 -- components/embedding/callback_extra.go | 23 -- components/indexer/callback_extra.go | 23 -- components/model/callback_extra.go | 26 -- components/prompt/callback_extra.go | 26 -- components/retriever/callback_extra.go | 26 -- components/tool/callback_extra.go | 27 -- flow/agent/multiagent/host/callback.go | 4 +- flow/agent/react/callback.go | 8 +- flow/agent/react/react_test.go | 3 +- .../template => utils/callbacks}/template.go | 347 +++++++++++------- .../callbacks}/template_test.go | 133 ++----- 16 files changed, 254 insertions(+), 493 deletions(-) delete mode 100644 callbacks/template/default.go rename {callbacks/template => utils/callbacks}/template.go (58%) rename {callbacks/template => utils/callbacks}/template_test.go (65%) diff --git a/callbacks/handler_builder.go b/callbacks/handler_builder.go index 41c7146..04bd54e 100644 --- a/callbacks/handler_builder.go +++ b/callbacks/handler_builder.go @@ -34,7 +34,7 @@ import ( // if err != nil {...} // runnable.Invoke(ctx, params, compose.WithCallback(handler)) // => only implement functions which you want to override // -// Deprecated: In most situations, it is preferred to use template.NewHandlerHelper. Otherwise, use NewHandlerBuilder().OnStartFn()...Build(). +// Deprecated: In most situations, it is preferred to use callbacks.NewHandlerHelper. Otherwise, use NewHandlerBuilder().OnStartFn()...Build(). type HandlerBuilder struct { OnStartFn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context OnEndFn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context diff --git a/callbacks/interface.go b/callbacks/interface.go index def6efa..5ace52d 100644 --- a/callbacks/interface.go +++ b/callbacks/interface.go @@ -92,7 +92,7 @@ const ( // TimingChecker checks if the handler is needed for the given callback aspect timing. // It's recommended for callback handlers to implement this interface, but not mandatory. -// If a callback handler is created by using template.HandlerHelper or handlerBuilder, then this interface is automatically implemented. +// If a callback handler is created by using callbacks.HandlerHelper or handlerBuilder, then this interface is automatically implemented. // Eino's callback mechanism will try to use this interface to determine whether any handlers are needed for the given timing. // Also, the callback handler that is not needed for that timing will be skipped. type TimingChecker interface { diff --git a/callbacks/template/default.go b/callbacks/template/default.go deleted file mode 100644 index bf84c5c..0000000 --- a/callbacks/template/default.go +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package template - -import ( - "context" - - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/schema" -) - -// DefaultCallbackHandler is the default callback handler implementation, can be used for callback handler builder in template.HandlerHelper (for example, Graph, StateGraph, Chain, Lambda, etc.). -type DefaultCallbackHandler struct { - OnStart func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context - OnStartWithStreamInput func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context - OnEnd func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context - OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (d *DefaultCallbackHandler) Needed(_ context.Context, _ *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return d.OnStart != nil - case callbacks.TimingOnEnd: - return d.OnEnd != nil - case callbacks.TimingOnError: - return d.OnError != nil - case callbacks.TimingOnStartWithStreamInput: - return d.OnStartWithStreamInput != nil - case callbacks.TimingOnEndWithStreamOutput: - return d.OnEndWithStreamOutput != nil - default: - return false - } -} diff --git a/components/document/callback_extra_loader.go b/components/document/callback_extra_loader.go index 61a3265..ee1a167 100644 --- a/components/document/callback_extra_loader.go +++ b/components/document/callback_extra_loader.go @@ -17,8 +17,6 @@ package document import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -71,24 +69,3 @@ func ConvLoaderCallbackOutput(src callbacks.CallbackOutput) *LoaderCallbackOutpu return nil } } - -// LoaderCallbackHandler is the handler for the loader callback. -type LoaderCallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *LoaderCallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *LoaderCallbackOutput) context.Context - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *LoaderCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/document/callback_extra_transformer.go b/components/document/callback_extra_transformer.go index 474e363..cd98f4a 100644 --- a/components/document/callback_extra_transformer.go +++ b/components/document/callback_extra_transformer.go @@ -17,8 +17,6 @@ package document import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -68,24 +66,3 @@ func ConvTransformerCallbackOutput(src callbacks.CallbackOutput) *TransformerCal return nil } } - -// TransformerCallbackHandler is the handler for the transformer callback. -type TransformerCallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *TransformerCallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *TransformerCallbackOutput) context.Context - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *TransformerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/embedding/callback_extra.go b/components/embedding/callback_extra.go index d091ed2..cab5151 100644 --- a/components/embedding/callback_extra.go +++ b/components/embedding/callback_extra.go @@ -17,8 +17,6 @@ package embedding import ( - "context" - "github.com/cloudwego/eino/callbacks" ) @@ -97,24 +95,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the embedding callback. -type CallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/indexer/callback_extra.go b/components/indexer/callback_extra.go index a596261..88de567 100644 --- a/components/indexer/callback_extra.go +++ b/components/indexer/callback_extra.go @@ -17,8 +17,6 @@ package indexer import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -66,24 +64,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the indexer callback. -type CallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index a5f420c..8270c81 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -17,8 +17,6 @@ package model import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -100,27 +98,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the model callback. -type CallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*CallbackOutput]) context.Context - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - case callbacks.TimingOnEndWithStreamOutput: - return ch.OnEndWithStreamOutput != nil - default: - return false - } -} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 1d85e66..324a418 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -17,8 +17,6 @@ package prompt import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -70,27 +68,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the callback. -type CallbackHandler struct { - // OnStart is the callback function for the start of the callback. - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context - // OnEnd is the callback function for the end of the callback. - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context - // OnError is the callback function for the error of the callback. - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/retriever/callback_extra.go b/components/retriever/callback_extra.go index 76ade40..49ea421 100644 --- a/components/retriever/callback_extra.go +++ b/components/retriever/callback_extra.go @@ -17,8 +17,6 @@ package retriever import ( - "context" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/schema" ) @@ -74,27 +72,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the retriever callback. -type CallbackHandler struct { - // OnStart is the callback function for the start of the retriever. - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context - // OnEnd is the callback function for the end of the retriever. - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context - // OnError is the callback function for the error of the retriever. - OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/components/tool/callback_extra.go b/components/tool/callback_extra.go index bb8ee7a..73e7d23 100644 --- a/components/tool/callback_extra.go +++ b/components/tool/callback_extra.go @@ -17,10 +17,7 @@ package tool import ( - "context" - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/schema" ) // CallbackInput is the input for the tool callback. @@ -62,27 +59,3 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return nil } } - -// CallbackHandler is the handler for the tool callback. -type CallbackHandler struct { - OnStart func(ctx context.Context, info *callbacks.RunInfo, input *CallbackInput) context.Context - OnEnd func(ctx context.Context, info *callbacks.RunInfo, input *CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*CallbackOutput]) context.Context - OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context -} - -// Needed checks if the callback handler is needed for the given timing. -func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { - switch timing { - case callbacks.TimingOnStart: - return ch.OnStart != nil - case callbacks.TimingOnEnd: - return ch.OnEnd != nil - case callbacks.TimingOnEndWithStreamOutput: - return ch.OnEndWithStreamOutput != nil - case callbacks.TimingOnError: - return ch.OnError != nil - default: - return false - } -} diff --git a/flow/agent/multiagent/host/callback.go b/flow/agent/multiagent/host/callback.go index eca57f6..be3f8f7 100644 --- a/flow/agent/multiagent/host/callback.go +++ b/flow/agent/multiagent/host/callback.go @@ -21,10 +21,10 @@ import ( "io" "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/callbacks/template" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/flow/agent" "github.com/cloudwego/eino/schema" + template "github.com/cloudwego/eino/utils/callbacks" ) // MultiAgentCallback is the callback interface for host multi-agent. @@ -114,7 +114,7 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler { return ctx } - return template.NewHandlerHelper().ChatModel(&model.CallbackHandler{ + return template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{ OnEnd: onChatModelEnd, OnEndWithStreamOutput: onChatModelEndWithStreamOutput, }).Handler() diff --git a/flow/agent/react/callback.go b/flow/agent/react/callback.go index 63758f7..dc00f51 100644 --- a/flow/agent/react/callback.go +++ b/flow/agent/react/callback.go @@ -18,17 +18,15 @@ package react import ( "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/callbacks/template" - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/components/tool" + template "github.com/cloudwego/eino/utils/callbacks" ) // BuildAgentCallback builds a callback handler for agent. -// eg. +// e.g. // // callback := BuildAgentCallback(modelHandler, toolHandler) // agent, err := react.NewAgent(ctx, &AgentConfig{}) // agent.Generate(ctx, input, agent.WithComposeOptions(compose.WithCallbacks(callback))) -func BuildAgentCallback(modelHandler *model.CallbackHandler, toolHandler *tool.CallbackHandler) callbacks.Handler { +func BuildAgentCallback(modelHandler *template.ModelCallbackHandler, toolHandler *template.ToolCallbackHandler) callbacks.Handler { return template.NewHandlerHelper().ChatModel(modelHandler).Tool(toolHandler).Handler() } diff --git a/flow/agent/react/react_test.go b/flow/agent/react/react_test.go index bad99f3..19e0d2a 100644 --- a/flow/agent/react/react_test.go +++ b/flow/agent/react/react_test.go @@ -34,6 +34,7 @@ import ( "github.com/cloudwego/eino/flow/agent" mockModel "github.com/cloudwego/eino/internal/mock/components/model" "github.com/cloudwego/eino/schema" + template "github.com/cloudwego/eino/utils/callbacks" ) func TestReact(t *testing.T) { @@ -602,4 +603,4 @@ func randStr() string { return string(b) } -var callbackForTest = BuildAgentCallback(&model.CallbackHandler{}, &tool.CallbackHandler{}) +var callbackForTest = BuildAgentCallback(&template.ModelCallbackHandler{}, &template.ToolCallbackHandler{}) diff --git a/callbacks/template/template.go b/utils/callbacks/template.go similarity index 58% rename from callbacks/template/template.go rename to utils/callbacks/template.go index b53d5a6..6b42459 100644 --- a/callbacks/template/template.go +++ b/utils/callbacks/template.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package template +package callbacks import ( "context" @@ -38,31 +38,30 @@ import ( // and fallbackTemplate can be used to handle scenarios where none of the cases are hit as a fallback. func NewHandlerHelper() *HandlerHelper { return &HandlerHelper{ - composeTemplates: map[components.Component]*DefaultCallbackHandler{}, + composeTemplates: map[components.Component]callbacks.Handler{}, } } // HandlerHelper is a builder for creating a callbacks.Handler with specific handlers for different component types. -// create a handler with template.NewHandlerHelper(). +// create a handler with callbacks.NewHandlerHelper(). // eg. // // helper := template.NewHandlerHelper(). -// ChatModel(&model.CallbackHandler{}). -// Prompt(&prompt.CallbackHandler{}). +// ChatModel(&model.IndexerCallbackHandler{}). +// Prompt(&prompt.IndexerCallbackHandler{}). // Handler() // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *prompt.CallbackHandler - chatModelHandler *model.CallbackHandler - embeddingHandler *embedding.CallbackHandler - indexerHandler *indexer.CallbackHandler - retrieverHandler *retriever.CallbackHandler - loaderHandler *document.LoaderCallbackHandler - transformerHandler *document.TransformerCallbackHandler - toolHandler *tool.CallbackHandler - composeTemplates map[components.Component]*DefaultCallbackHandler - fallbackTemplate *DefaultCallbackHandler // execute when not matching any other condition + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -71,95 +70,77 @@ func (c *HandlerHelper) Handler() callbacks.Handler { } // Prompt sets the prompt handler for the handler helper, which will be called when the prompt component is executed. -func (c *HandlerHelper) Prompt(handler *prompt.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Prompt(handler *PromptCallbackHandler) *HandlerHelper { c.promptHandler = handler return c } // ChatModel sets the chat model handler for the handler helper, which will be called when the chat model component is executed. -func (c *HandlerHelper) ChatModel(handler *model.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) ChatModel(handler *ModelCallbackHandler) *HandlerHelper { c.chatModelHandler = handler return c } // Embedding sets the embedding handler for the handler helper, which will be called when the embedding component is executed. -func (c *HandlerHelper) Embedding(handler *embedding.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Embedding(handler *EmbeddingCallbackHandler) *HandlerHelper { c.embeddingHandler = handler return c } // Indexer sets the indexer handler for the handler helper, which will be called when the indexer component is executed. -func (c *HandlerHelper) Indexer(handler *indexer.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Indexer(handler *IndexerCallbackHandler) *HandlerHelper { c.indexerHandler = handler return c } // Retriever sets the retriever handler for the handler helper, which will be called when the retriever component is executed. -func (c *HandlerHelper) Retriever(handler *retriever.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Retriever(handler *RetrieverCallbackHandler) *HandlerHelper { c.retrieverHandler = handler return c } // Loader sets the loader handler for the handler helper, which will be called when the loader component is executed. -func (c *HandlerHelper) Loader(handler *document.LoaderCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Loader(handler *LoaderCallbackHandler) *HandlerHelper { c.loaderHandler = handler return c } // Transformer sets the transformer handler for the handler helper, which will be called when the transformer component is executed. -func (c *HandlerHelper) Transformer(handler *document.TransformerCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Transformer(handler *TransformerCallbackHandler) *HandlerHelper { c.transformerHandler = handler return c } // Tool sets the tool handler for the handler helper, which will be called when the tool component is executed. -func (c *HandlerHelper) Tool(handler *tool.CallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Tool(handler *ToolCallbackHandler) *HandlerHelper { c.toolHandler = handler return c } // Graph sets the graph handler for the handler helper, which will be called when the graph is executed. -func (c *HandlerHelper) Graph(handler *DefaultCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Graph(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfGraph] = handler return c } -// StateGraph sets the state graph handler for the handler helper, which will be called when the state graph is executed. -func (c *HandlerHelper) StateGraph(handler *DefaultCallbackHandler) *HandlerHelper { - c.composeTemplates[compose.ComponentOfStateGraph] = handler - return c -} - // Chain sets the chain handler for the handler helper, which will be called when the chain is executed. -func (c *HandlerHelper) Chain(handler *DefaultCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Chain(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfChain] = handler return c } -// Passthrough sets the passthrough handler for the handler helper, which will be called when the passthrough is executed. -func (c *HandlerHelper) Passthrough(handler *DefaultCallbackHandler) *HandlerHelper { - c.composeTemplates[compose.ComponentOfPassthrough] = handler - return c -} - // ToolsNode sets the tools node handler for the handler helper, which will be called when the tools node is executed. -func (c *HandlerHelper) ToolsNode(handler *DefaultCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) ToolsNode(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfToolsNode] = handler return c } // Lambda sets the lambda handler for the handler helper, which will be called when the lambda is executed. -func (c *HandlerHelper) Lambda(handler *DefaultCallbackHandler) *HandlerHelper { +func (c *HandlerHelper) Lambda(handler callbacks.Handler) *HandlerHelper { c.composeTemplates[compose.ComponentOfLambda] = handler return c } -// Fallback sets the fallback handler for the handler helper, which will be called when no other handlers are matched. -func (c *HandlerHelper) Fallback(handler *DefaultCallbackHandler) *HandlerHelper { - c.fallbackTemplate = handler - return c -} - type handlerTemplate struct { *HandlerHelper } @@ -171,66 +152,49 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return ctx } - match := false - switch info.Component { case components.ComponentOfPrompt: if c.promptHandler != nil && c.promptHandler.OnStart != nil { - match = true ctx = c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) } case components.ComponentOfChatModel: if c.chatModelHandler != nil && c.chatModelHandler.OnStart != nil { - match = true ctx = c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.OnStart != nil { - match = true ctx = c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) } case components.ComponentOfIndexer: if c.indexerHandler != nil && c.indexerHandler.OnStart != nil { - match = true ctx = c.indexerHandler.OnStart(ctx, info, indexer.ConvCallbackInput(input)) } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.OnStart != nil { - match = true ctx = c.retrieverHandler.OnStart(ctx, info, retriever.ConvCallbackInput(input)) } case components.ComponentOfLoader: if c.loaderHandler != nil && c.loaderHandler.OnStart != nil { - match = true ctx = c.loaderHandler.OnStart(ctx, info, document.ConvLoaderCallbackInput(input)) } case components.ComponentOfTransformer: if c.transformerHandler != nil && c.transformerHandler.OnStart != nil { - match = true ctx = c.transformerHandler.OnStart(ctx, info, document.ConvTransformerCallbackInput(input)) } case components.ComponentOfTool: if c.toolHandler != nil && c.toolHandler.OnStart != nil { - match = true ctx = c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) } case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, - compose.ComponentOfPassthrough, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnStart != nil { - match = true + if c.composeTemplates[info.Component] != nil { ctx = c.composeTemplates[info.Component].OnStart(ctx, info, input) } default: - - } - - if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnStart != nil { - ctx = c.fallbackTemplate.OnStart(ctx, info, input) + return ctx } return ctx @@ -243,66 +207,49 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return ctx } - match := false - switch info.Component { case components.ComponentOfPrompt: if c.promptHandler != nil && c.promptHandler.OnEnd != nil { - match = true ctx = c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) } case components.ComponentOfChatModel: if c.chatModelHandler != nil && c.chatModelHandler.OnEnd != nil { - match = true ctx = c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.OnEnd != nil { - match = true ctx = c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) } case components.ComponentOfIndexer: if c.indexerHandler != nil && c.indexerHandler.OnEnd != nil { - match = true ctx = c.indexerHandler.OnEnd(ctx, info, indexer.ConvCallbackOutput(output)) } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.OnEnd != nil { - match = true ctx = c.retrieverHandler.OnEnd(ctx, info, retriever.ConvCallbackOutput(output)) } case components.ComponentOfLoader: if c.loaderHandler != nil && c.loaderHandler.OnEnd != nil { - match = true ctx = c.loaderHandler.OnEnd(ctx, info, document.ConvLoaderCallbackOutput(output)) } case components.ComponentOfTransformer: if c.transformerHandler != nil && c.transformerHandler.OnEnd != nil { - match = true ctx = c.transformerHandler.OnEnd(ctx, info, document.ConvTransformerCallbackOutput(output)) } case components.ComponentOfTool: if c.toolHandler != nil && c.toolHandler.OnEnd != nil { - match = true ctx = c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) } case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, - compose.ComponentOfPassthrough, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnEnd != nil { - match = true + if c.composeTemplates[info.Component] != nil { ctx = c.composeTemplates[info.Component].OnEnd(ctx, info, output) } default: - - } - - if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnEnd != nil { - ctx = c.fallbackTemplate.OnEnd(ctx, info, output) + return ctx } return ctx @@ -315,47 +262,37 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return ctx } - match := false - switch info.Component { case components.ComponentOfPrompt: if c.promptHandler != nil && c.promptHandler.OnError != nil { - match = true ctx = c.promptHandler.OnError(ctx, info, err) } case components.ComponentOfChatModel: if c.chatModelHandler != nil && c.chatModelHandler.OnError != nil { - match = true ctx = c.chatModelHandler.OnError(ctx, info, err) } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.OnError != nil { - match = true ctx = c.embeddingHandler.OnError(ctx, info, err) } case components.ComponentOfIndexer: if c.indexerHandler != nil && c.indexerHandler.OnError != nil { - match = true ctx = c.indexerHandler.OnError(ctx, info, err) } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.OnError != nil { - match = true ctx = c.retrieverHandler.OnError(ctx, info, err) } case components.ComponentOfLoader: if c.loaderHandler != nil && c.loaderHandler.OnError != nil { - match = true ctx = c.loaderHandler.OnError(ctx, info, err) } case components.ComponentOfTransformer: if c.transformerHandler != nil && c.transformerHandler.OnError != nil { - match = true ctx = c.transformerHandler.OnError(ctx, info, err) } case components.ComponentOfTool: if c.toolHandler != nil && c.toolHandler.OnError != nil { - match = true ctx = c.toolHandler.OnError(ctx, info, err) } case compose.ComponentOfGraph, @@ -365,16 +302,11 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnError != nil { - match = true + if c.composeTemplates[info.Component] != nil { ctx = c.composeTemplates[info.Component].OnError(ctx, info, err) } default: - - } - - if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnError != nil { - ctx = c.fallbackTemplate.OnError(ctx, info, err) + return ctx } return ctx @@ -397,22 +329,15 @@ func (c *handlerTemplate) OnStartWithStreamInput(ctx context.Context, info *call switch info.Component { // currently no components.Component receive stream as input case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, - compose.ComponentOfPassthrough, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnStartWithStreamInput != nil { + if c.composeTemplates[info.Component] != nil { match = true ctx = c.composeTemplates[info.Component].OnStartWithStreamInput(ctx, info, input) } default: - - } - - if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnStartWithStreamInput != nil { - match = true - ctx = c.fallbackTemplate.OnStartWithStreamInput(ctx, info, input) + return ctx } return ctx @@ -451,23 +376,16 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb })) } case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, - compose.ComponentOfPassthrough, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnEndWithStreamOutput != nil { + if c.composeTemplates[info.Component] != nil { match = true ctx = c.composeTemplates[info.Component].OnEndWithStreamOutput(ctx, info, output) } default: - - } - - if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnEndWithStreamOutput != nil { - match = true - ctx = c.fallbackTemplate.OnEndWithStreamOutput(ctx, info, output) + return ctx } return ctx @@ -509,22 +427,199 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t return true } case compose.ComponentOfGraph, - compose.ComponentOfStateGraph, compose.ComponentOfChain, - compose.ComponentOfPassthrough, compose.ComponentOfToolsNode, compose.ComponentOfLambda: - template := c.composeTemplates[info.Component] - if template != nil && template.Needed(ctx, info, timing) { - return true + handler := c.composeTemplates[info.Component] + if handler != nil { + checker, ok := handler.(callbacks.TimingChecker) + if !ok || checker.Needed(ctx, info, timing) { + return true + } } default: + return false + } + + return false +} + +// LoaderCallbackHandler is the handler for the loader callback. +type LoaderCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.LoaderCallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *LoaderCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} +// TransformerCallbackHandler is the handler for the transformer callback. +type TransformerCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.TransformerCallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *TransformerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false } +} + +// EmbeddingCallbackHandler is the handler for the embedding callback. +type EmbeddingCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *embedding.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *embedding.CallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} - if c.fallbackTemplate != nil { - return c.fallbackTemplate.Needed(ctx, info, timing) +// Needed checks if the callback handler is needed for the given timing. +func (ch *EmbeddingCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false } +} - return false +// IndexerCallbackHandler is the handler for the indexer callback. +type IndexerCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *indexer.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *indexer.CallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *IndexerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// ModelCallbackHandler is the handler for the model callback. +type ModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *ModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// PromptCallbackHandler is the handler for the callback. +type PromptCallbackHandler struct { + // OnStart is the callback function for the start of the callback. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the callback. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the callback. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *PromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// RetrieverCallbackHandler is the handler for the retriever callback. +type RetrieverCallbackHandler struct { + // OnStart is the callback function for the start of the retriever. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context + // OnEnd is the callback function for the end of the retriever. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context + // OnError is the callback function for the error of the retriever. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *RetrieverCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// ToolCallbackHandler is the handler for the tool callback. +type ToolCallbackHandler struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*tool.CallbackOutput]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *ToolCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } } diff --git a/callbacks/template/template_test.go b/utils/callbacks/template_test.go similarity index 65% rename from callbacks/template/template_test.go rename to utils/callbacks/template_test.go index a6c86e4..a621f58 100644 --- a/callbacks/template/template_test.go +++ b/utils/callbacks/template_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package template +package callbacks import ( "context" @@ -37,10 +37,10 @@ import ( ) func TestNewComponentTemplate(t *testing.T) { - t.Run("test no fallback", func(t *testing.T) { + t.Run("TestNewComponentTemplate", func(t *testing.T) { cnt := 0 tpl := NewHandlerHelper() - tpl.ChatModel(&model.CallbackHandler{ + tpl.ChatModel(&ModelCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { cnt++ return ctx @@ -58,7 +58,7 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }}). - Embedding(&embedding.CallbackHandler{ + Embedding(&EmbeddingCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { cnt++ return ctx @@ -72,7 +72,7 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }, }). - Prompt(&prompt.CallbackHandler{ + Prompt(&PromptCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { cnt++ return ctx @@ -86,7 +86,7 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }, }). - Retriever(&retriever.CallbackHandler{ + Retriever(&RetrieverCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { cnt++ return ctx @@ -100,7 +100,7 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }, }). - Tool(&tool.CallbackHandler{ + Tool(&ToolCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context { cnt++ return ctx @@ -118,32 +118,32 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }, }). - Lambda(&DefaultCallbackHandler{ - OnStart: func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + Lambda(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { cnt++ return ctx - }, - OnStartWithStreamInput: func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { + }). + OnStartWithStreamInputFn(func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { input.Close() cnt++ return ctx - }, - OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { cnt++ return ctx - }, - OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { + }). + OnEndWithStreamOutputFn(func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { output.Close() cnt++ return ctx - }, - OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { cnt++ return ctx - }, - }) + }).Build()). + Handler() - typs := []components.Component{ + types := []components.Component{ components.ComponentOfPrompt, components.ComponentOfLoaderSplitter, components.ComponentOfChatModel, @@ -155,7 +155,7 @@ func TestNewComponentTemplate(t *testing.T) { handler := tpl.Handler() ctx := context.Background() - for _, typ := range typs { + for _, typ := range types { handler.OnStart(ctx, &callbacks.RunInfo{Component: typ}, nil) handler.OnEnd(ctx, &callbacks.RunInfo{Component: typ}, nil) handler.OnError(ctx, &callbacks.RunInfo{Component: typ}, fmt.Errorf("mock err")) @@ -192,7 +192,7 @@ func TestNewComponentTemplate(t *testing.T) { callbacks.OnStart(ctx, nil) assert.Equal(t, 24, cnt) - tpl.Transformer(&document.TransformerCallbackHandler{ + tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { cnt++ return ctx @@ -205,7 +205,7 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }, - }).Indexer(&indexer.CallbackHandler{ + }).Indexer(&IndexerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { cnt++ return ctx @@ -218,7 +218,7 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }, - }).Loader(&document.LoaderCallbackHandler{ + }).Loader(&LoaderCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { cnt++ return ctx @@ -247,89 +247,4 @@ func TestNewComponentTemplate(t *testing.T) { callbacks.OnEnd(ctx, nil) assert.Equal(t, 27, cnt) }) - - t.Run("test fallback", func(t *testing.T) { - cnt, cntf := 0, 0 - tpl := NewHandlerHelper(). - Retriever(&retriever.CallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { - cnt++ - return ctx - }, - OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context { - cnt++ - return ctx - }, - OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { - cnt++ - return ctx - }, - }). - Fallback(&DefaultCallbackHandler{ - OnStart: func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { - cntf++ - return ctx - }, - OnStartWithStreamInput: func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { - input.Close() - cntf++ - return ctx - }, - OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { - cntf++ - return ctx - }, - OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { - output.Close() - cntf++ - return ctx - }, - OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { - cntf++ - return ctx - }, - }) - - handler := tpl.Handler() - ctx := context.Background() - handler.OnStart(ctx, &callbacks.RunInfo{ - Component: compose.ComponentOfLambda, - }, nil) - - handler.OnEnd(ctx, &callbacks.RunInfo{ - Component: compose.ComponentOfLambda, - }, nil) - - handler.OnError(ctx, &callbacks.RunInfo{ - Component: compose.ComponentOfLambda, - }, fmt.Errorf("mock err")) - - sir, siw := schema.Pipe[callbacks.CallbackInput](1) - siw.Close() - handler.OnStartWithStreamInput(ctx, &callbacks.RunInfo{ - Component: compose.ComponentOfLambda, - }, sir) - - sor, sow := schema.Pipe[callbacks.CallbackOutput](1) - sow.Close() - handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{ - Component: compose.ComponentOfLambda, - }, sor) - - assert.Equal(t, 0, cnt) - assert.Equal(t, 5, cntf) - - ctx = context.Background() - ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, handler) - callbacks.OnStart(ctx, nil) - callbacks.OnEnd(ctx, nil) - callbacks.OnError(ctx, nil) - callbacks.OnStartWithStreamInput(ctx, &schema.StreamReader[callbacks.CallbackInput]{}) - callbacks.OnEndWithStreamOutput(ctx, &schema.StreamReader[callbacks.CallbackOutput]{}) - assert.Equal(t, 10, cntf) - - ctx = callbacks.SetRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}) - callbacks.OnStart(ctx, nil) - assert.Equal(t, 1, cnt) - }) }