Skip to content

Commit

Permalink
feat: remove state chain/graph
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn committed Dec 24, 2024
1 parent 1f94cb2 commit 91df0ea
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 74 deletions.
2 changes: 1 addition & 1 deletion compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (g *graph) component() component {
}

func isChain(cmp component) bool {
return cmp == ComponentOfChain || cmp == ComponentOfStateChain
return cmp == ComponentOfChain
}

// ErrGraphCompiled is returned when attempting to modify a graph after it has been compiled
Expand Down
2 changes: 1 addition & 1 deletion compose/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ func TestGraphCompileCallback(t *testing.T) {
t.Run("graph compile callback", func(t *testing.T) {
type s struct{}

g := NewStateGraph[map[string]any, map[string]any, *s](func(ctx context.Context) *s { return &s{} })
g := NewGraph[map[string]any, map[string]any](WithGenLocalState(func(ctx context.Context) *s { return &s{} }))

lambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "node1", nil
Expand Down
66 changes: 0 additions & 66 deletions compose/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,72 +25,6 @@ import (
"github.com/cloudwego/eino/utils/generic"
)

// NewStateGraph creates a new state graph. It requires a func of GenLocalState to generate the state.
//
// Deprecated: NewStateGraph is deprecated and will be removed in a future version.
// Use NewGraph with WithGenLocalState option instead:
//
// // Instead of:
// graph := NewStateGraph[Input, Output, State](genStateFunc)
//
// // Use:
// graph := NewGraph[Input, Output](WithGenLocalState(genStateFunc))
func NewStateGraph[I, O, S any](gen GenLocalState[S]) *StateGraph[I, O, S] {
sg := &StateGraph[I, O, S]{NewGraph[I, O](WithGenLocalState(gen))}

sg.cmp = ComponentOfStateGraph

return sg
}

// StateGraph is a graph that shares state between nodes. It's useful when you want to share some data across nodes.
//
// Deprecated: StateGraph is deprecated and will be removed in a future version.
// Use Graph with WithGenLocalState option instead:
//
// // Instead of:
// graph := NewStateGraph[Input, Output, State](genStateFunc)
//
// // Use:
// graph := NewGraph[Input, Output](WithGenLocalState(genStateFunc))
type StateGraph[I, O, S any] struct {
*Graph[I, O]
}

// NewStateChain creates a new state chain. It requires a func of GenLocalState to generate the state.
//
// Deprecated: NewStateChain is deprecated and will be removed in a future version.
// Use NewChain with WithGenLocalState option instead:
//
// // Instead of:
// chain := NewStateChain[Input, Output, State](genStateFunc)
//
// // Use:
// chain := NewChain[Input, Output](WithGenLocalState(genStateFunc))
func NewStateChain[I, O, S any](gen GenLocalState[S]) *StateChain[I, O, S] {
sc := &StateChain[I, O, S]{NewChain[I, O](WithGenLocalState(gen))}

sc.gg.cmp = ComponentOfStateChain

return sc
}

// StateChain is a chain that shares state between nodes. State is shared between nodes in the chain.
// It's useful when you want to share some data across nodes in a chain.
// you can use WithPreHandler and WithPostHandler to do something with state of this chain.
//
// Deprecated: StateChain is deprecated and will be removed in a future version.
// Use Chain with WithGenLocalState option instead:
//
// // Instead of:
// chain := NewStateChain[Input, Output, State](genStateFunc)
//
// // Use:
// chain := NewChain[Input, Output](WithGenLocalState(genStateFunc))
type StateChain[I, O, S any] struct {
*Chain[I, O]
}

// GenLocalState is a function that generates the state.
type GenLocalState[S any] func(ctx context.Context) (state S)

Expand Down
8 changes: 4 additions & 4 deletions compose/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestStateGraphWithEdge(t *testing.T) {
return &testState{}
}

sg := NewStateGraph[string, string, *testState](gen)
sg := NewGraph[string, string](WithGenLocalState(gen))

l1 := InvokableLambda(func(ctx context.Context, in string) (out midStr, err error) {
return midStr("InvokableLambda: " + in), nil
Expand Down Expand Up @@ -221,9 +221,9 @@ func TestStateChain(t *testing.T) {
Field1 string
Field2 string
}
sc := NewStateChain[string, string, *testState](func(ctx context.Context) (state *testState) {
sc := NewChain[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) {
return &testState{}
})
}))

r, err := sc.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
s, err := GetState[*testState](ctx)
Expand Down Expand Up @@ -259,7 +259,7 @@ func TestStreamState(t *testing.T) {
}
ctx := context.Background()
s := &testState{Field1: "1"}
g := NewStateGraph[string, string, *testState](func(ctx context.Context) (state *testState) { return s })
g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state *testState) { return s }))
err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) {
return input, nil
}), WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) {
Expand Down
1 change: 0 additions & 1 deletion compose/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ type component = components.Component
const (
ComponentOfUnknown component = "Unknown"
ComponentOfGraph component = "Graph"
ComponentOfStateGraph component = "StateGraph"
ComponentOfChain component = "Chain"
ComponentOfStateChain component = "StateChain"
ComponentOfPassthrough component = "Passthrough"
Expand Down
1 change: 0 additions & 1 deletion utils/callbacks/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo,
ctx = c.toolHandler.OnError(ctx, info, err)
}
case compose.ComponentOfGraph,
compose.ComponentOfStateGraph,
compose.ComponentOfChain,
compose.ComponentOfPassthrough,
compose.ComponentOfToolsNode,
Expand Down

0 comments on commit 91df0ea

Please sign in to comment.