Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 17, 2023
1 parent 674c8bb commit dec1985
Show file tree
Hide file tree
Showing 25 changed files with 310 additions and 123 deletions.
2 changes: 1 addition & 1 deletion _examples/retrieval_qa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
log.Fatal(err)
}

result, err := chain.Run(context.Background(), retrievalQAChain, "Why don't scientists trust atoms?")
result, err := retrievalQAChain.Run(context.Background(), "Why don't scientists trust atoms?")
if err != nil {
log.Fatal(err)
}
Expand Down
14 changes: 14 additions & 0 deletions callback/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,17 @@ func (m *Manager) OnChainEnd(outputs *schema.ChainValues) error {

return nil
}

func (m *Manager) OnChainError(chainError error) error {
for _, c := range m.callbacks {
if m.verbose || c.AlwaysVerbose() {
if err := c.OnChainError(chainError); err != nil {
if c.RaiseError() {
return err
}
}
}
}

return nil
}
78 changes: 49 additions & 29 deletions chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,62 +4,87 @@ import (
"context"
"strings"

"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/schema"
)

type callbackOptions struct {
Callbacks []schema.Callback
Verbose bool
}

type callFunc func(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error)

type chain struct {
callFunc callFunc
inputKeys []string
outputKeys []string
type baseChain struct {
chainName string
callFunc callFunc
inputKeys []string
outputKeys []string
memory schema.Memory
callbackOptions *callbackOptions
}

func newChain(callFunc callFunc, inputKeys []string, outputKeys []string) *chain {
return &chain{
callFunc: callFunc,
inputKeys: inputKeys,
outputKeys: outputKeys,
func (bc *baseChain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
cm := callback.NewManager(bc.callbackOptions.Callbacks, bc.callbackOptions.Verbose)

if err := cm.OnChainStart(bc.chainName, &inputs); err != nil {
return nil, err
}
}

func (c *chain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
return c.callFunc(ctx, inputs)
output, err := bc.callFunc(ctx, inputs)
if err != nil {
if cbError := cm.OnChainError(err); cbError != nil {
return nil, cbError
}

return nil, err
}

if err := cm.OnChainEnd(&output); err != nil {
return nil, err
}

return output, nil
}

func (c *chain) Run(ctx context.Context, input any) (string, error) {
if len(c.inputKeys) != 1 {
func (bc *baseChain) Run(ctx context.Context, input any) (string, error) {
if len(bc.inputKeys) != 1 {
return "", ErrMultipleInputsInRun
}

if len(c.outputKeys) != 1 {
if len(bc.outputKeys) != 1 {
return "", ErrMultipleOutputsInRun
}

inputValues := map[string]any{c.inputKeys[0]: input}
inputValues := map[string]any{bc.inputKeys[0]: input}

// TODO
if bc.memory != nil {
_, _ = bc.memory.LoadMemoryVariables(inputValues)
}

outputValues, err := c.Call(ctx, inputValues)
outputValues, err := bc.Call(ctx, inputValues)
if err != nil {
return "", err
}

outputValue, ok := outputValues[c.outputKeys[0]].(string)
outputValue, ok := outputValues[bc.outputKeys[0]].(string)
if !ok {
return "", ErrWrongOutputTypeInRun
}

return strings.TrimSpace(outputValue), nil
}

func (c *chain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schema.ChainValues, error) {
func (bc *baseChain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schema.ChainValues, error) {
chainValues := []schema.ChainValues{}

for _, input := range inputs {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
vals, err := c.Call(ctx, input)
vals, err := bc.Call(ctx, input)
if err != nil {
return nil, err
}
Expand All @@ -72,16 +97,11 @@ func (c *chain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schem
}

// InputKeys returns the expected input keys.
func (c *chain) InputKeys() []string {
return c.inputKeys
func (bc *baseChain) InputKeys() []string {
return bc.inputKeys
}

// OutputKeys returns the output keys the chain will return.
func (c *chain) OutputKeys() []string {
return c.outputKeys
}

type callbackOptions struct {
Callbacks []schema.Callback
Verbose bool
func (bc *baseChain) OutputKeys() []string {
return bc.outputKeys
}
57 changes: 57 additions & 0 deletions chain/conversation_chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package chain

import (
"context"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/memory"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

type ConversationChainOptions struct {
*callbackOptions
Memory schema.Memory
OutputParser schema.OutputParser[any]
}

type ConversationChain struct {
*baseChain
llm schema.LLM
prompt *prompt.Template
opts ConversationChainOptions
}

func NewConversationChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *ConversationChainOptions)) (*ConversationChain, error) {
opts := ConversationChainOptions{
Memory: memory.NewConversationBuffer(),
callbackOptions: &callbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
fn(&opts)
}

conversationChain := &ConversationChain{
prompt: prompt,
llm: llm,
opts: opts,
}

conversationChain.baseChain = &baseChain{
chainName: "ConversationChain",
callFunc: conversationChain.call,
inputKeys: []string{"input"},
outputKeys: []string{"response"},
memory: opts.Memory,
callbackOptions: opts.callbackOptions,
}

return conversationChain, nil
}

func (c *ConversationChain) call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
return nil, nil
}
21 changes: 18 additions & 3 deletions chain/llm_bash_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/integration"
"github.com/hupe1980/golc/outputparser"
"github.com/hupe1980/golc/prompt"
Expand Down Expand Up @@ -31,21 +32,29 @@ That is the format. Begin!
Question: {{.question}}`

type LLMBashChainOptions struct {
*callbackOptions
InputKey string
OutputKey string
}

type LLMBashChain struct {
*chain
*baseChain
llmChain *LLMChain
bashProcess *integration.BashProcess
opts LLMBashChainOptions
}

func NewLLMBashChain(llmChain *LLMChain) (*LLMBashChain, error) {
func NewLLMBashChain(llmChain *LLMChain, optFns ...func(o *LLMBashChainOptions)) (*LLMBashChain, error) {
opts := LLMBashChainOptions{
InputKey: "question",
OutputKey: "answer",
callbackOptions: &callbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
fn(&opts)
}

bp, err := integration.NewBashProcess()
Expand All @@ -59,7 +68,13 @@ func NewLLMBashChain(llmChain *LLMChain) (*LLMBashChain, error) {
opts: opts,
}

bash.chain = newChain(bash.call, []string{opts.InputKey}, []string{opts.OutputKey})
bash.baseChain = &baseChain{
chainName: "LLMBashChain",
callFunc: bash.call,
inputKeys: []string{opts.InputKey},
outputKeys: []string{opts.OutputKey},
callbackOptions: opts.callbackOptions,
}

return bash, nil
}
Expand Down
58 changes: 25 additions & 33 deletions chain/llm_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ import (
"context"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

type LLMChainOptions struct {
callbackOptions
*callbackOptions
Memory schema.Memory
OutputKey string
OutputParser schema.OutputParser[any]
}

type LLMChain struct {
*chain
*baseChain
llm schema.LLM
prompt *prompt.Template
opts LLMChainOptions
Expand All @@ -25,7 +25,7 @@ type LLMChain struct {
func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMChainOptions)) (*LLMChain, error) {
opts := LLMChainOptions{
OutputKey: "text",
callbackOptions: callbackOptions{
callbackOptions: &callbackOptions{
Verbose: golc.Verbose,
},
}
Expand All @@ -40,32 +40,29 @@ func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMC
opts: opts,
}

llmChain.chain = newChain(llmChain.call, prompt.InputVariables(), []string{opts.OutputKey})
llmChain.baseChain = &baseChain{
chainName: "LLMChain",
callFunc: llmChain.call,
inputKeys: prompt.InputVariables(),
outputKeys: []string{opts.OutputKey},
memory: opts.Memory,
callbackOptions: opts.callbackOptions,
}

return llmChain, nil
}

func (c *LLMChain) Type() string {
return "llm_chain"
}

func (c *LLMChain) Predict(ctx context.Context, values schema.ChainValues) (string, error) {
output, err := c.Call(ctx, values)
func (c *LLMChain) Predict(ctx context.Context, inputs schema.ChainValues) (string, error) {
output, err := c.Call(ctx, inputs)
if err != nil {
return "", err
}

return output[c.opts.OutputKey].(string), err
}

func (c *LLMChain) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
cm := callback.NewManager(c.opts.Callbacks, c.opts.Verbose)

if err := cm.OnChainStart("LLMChain", &values); err != nil {
return nil, err
}

promptValue, err := c.prompt.FormatPrompt(values)
func (c *LLMChain) call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
promptValue, err := c.prompt.FormatPrompt(inputs)
if err != nil {
return nil, err
}
Expand All @@ -77,26 +74,21 @@ func (c *LLMChain) call(ctx context.Context, values schema.ChainValues) (schema.
return nil, err
}

output, err := c.getFinalOutput(res.Generations[0])
if err != nil {
return nil, err
}

if err := cm.OnChainEnd(&schema.ChainValues{"outputs": output}); err != nil {
return nil, err
}

return schema.ChainValues{
c.opts.OutputKey: output,
c.opts.OutputKey: c.getFinalOutput(res.Generations),
}, nil
}

func (c *LLMChain) Prompt() *prompt.Template {
return c.prompt
}

func (c *LLMChain) getFinalOutput(generations []*schema.Generation) (any, error) { // nolint unparam
completion := generations[0].Text
// TODO Outputparser
return completion, nil
func (c *LLMChain) getFinalOutput(generations [][]*schema.Generation) string {
output := []string{}
for _, generation := range generations {
// Get the text of the top generated string.
output = append(output, generation[0].Text)
}

return output[0]
}
Loading

0 comments on commit dec1985

Please sign in to comment.