diff --git a/genkit-tools/cli/src/commands/flow-batch-run.ts b/genkit-tools/cli/src/commands/flow-batch-run.ts index dd2dfdc05..a739cec46 100644 --- a/genkit-tools/cli/src/commands/flow-batch-run.ts +++ b/genkit-tools/cli/src/commands/flow-batch-run.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { FlowInvokeEnvelopeMessage } from '@genkit-ai/tools-common'; import { logger } from '@genkit-ai/tools-common/utils'; import { Command } from 'commander'; import { readFile, writeFile } from 'fs/promises'; @@ -59,13 +58,11 @@ export const flowBatchRun = new Command('flow:batchRun') logger.info(`Running '/flow/${flowName}'...`); let response = await manager.runAction({ key: `/flow/${flowName}`, - input: { - start: { - input: data, - labels: options.label ? { batchRun: options.label } : undefined, - auth: options.auth ? JSON.parse(options.auth) : undefined, - }, - } as FlowInvokeEnvelopeMessage, + input: data, + context: options.auth ? JSON.parse(options.auth) : undefined, + telemetryLabels: options.label + ? { batchRun: options.label } + : undefined, }); logger.info( 'Result:\n' + JSON.stringify(response.result, undefined, ' ') diff --git a/genkit-tools/cli/src/commands/flow-run.ts b/genkit-tools/cli/src/commands/flow-run.ts index b0491a14a..f7954f0a7 100644 --- a/genkit-tools/cli/src/commands/flow-run.ts +++ b/genkit-tools/cli/src/commands/flow-run.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { FlowInvokeEnvelopeMessage } from '@genkit-ai/tools-common'; import { logger } from '@genkit-ai/tools-common/utils'; import { Command } from 'commander'; import { writeFile } from 'fs/promises'; @@ -50,12 +49,8 @@ export const flowRun = new Command('flow:run') await manager.runAction( { key: `/flow/${flowName}`, - input: { - start: { - input: data ? JSON.parse(data) : undefined, - }, - auth: options.auth ? JSON.parse(options.auth) : undefined, - } as FlowInvokeEnvelopeMessage, + input: data ? JSON.parse(data) : undefined, + context: options.auth ? JSON.parse(options.auth) : undefined, }, options.stream ? (chunk) => console.log(JSON.stringify(chunk, undefined, ' ')) diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts index 28654afaf..0b4ea4266 100644 --- a/genkit-tools/common/src/eval/evaluate.ts +++ b/genkit-tools/common/src/eval/evaluate.ts @@ -25,7 +25,6 @@ import { EvalKeyAugments, EvalRun, EvalRunKey, - FlowActionInputSchema, GenerateRequest, GenerateRequestSchema, GenerateResponseSchema, @@ -257,15 +256,10 @@ async function runFlowAction(params: { const { manager, actionRef, testCase, auth } = { ...params }; let state: InferenceRunState; try { - const flowInput = FlowActionInputSchema.parse({ - start: { - input: testCase.input, - }, - auth: auth ? JSON.parse(auth) : undefined, - }); const runActionResponse = await manager.runAction({ key: actionRef, - input: flowInput, + input: testCase.input, + context: auth ? JSON.parse(auth) : undefined, }); state = { ...testCase, diff --git a/genkit-tools/common/src/manager/manager.ts b/genkit-tools/common/src/manager/manager.ts index 77269cd39..1f4d15109 100644 --- a/genkit-tools/common/src/manager/manager.ts +++ b/genkit-tools/common/src/manager/manager.ts @@ -42,6 +42,7 @@ import { const STREAM_DELIMITER = '\n'; const HEALTH_CHECK_INTERVAL = 5000; +export const GENKIT_REFLECTION_API_SPEC_VERSION = 1; interface RuntimeManagerOptions { /** URL of the telemetry server. */ @@ -278,6 +279,7 @@ export class RuntimeManager { try { await axios.post(`${runtime.reflectionServerUrl}/api/notify`, { telemetryServerUrl: this.telemetryServerUrl, + reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION, }); } catch (error) { logger.error(`Failed to notify runtime ${runtime.id}: ${error}`); @@ -326,6 +328,27 @@ export class RuntimeManager { if (isValidRuntimeInfo(runtimeInfo)) { const fileName = path.basename(filePath); if (await checkServerHealth(runtimeInfo.reflectionServerUrl)) { + if ( + runtimeInfo.reflectionApiSpecVersion != + GENKIT_REFLECTION_API_SPEC_VERSION + ) { + if ( + !runtimeInfo.reflectionApiSpecVersion || + runtimeInfo.reflectionApiSpecVersion < + GENKIT_REFLECTION_API_SPEC_VERSION + ) { + logger.warn( + 'Genkit CLI is newer than runtime library. Some feature may not be supported. ' + + 'Consider upgrading your runtime library version (debug info: expected ' + + `${GENKIT_REFLECTION_API_SPEC_VERSION}, got ${runtimeInfo.reflectionApiSpecVersion}).` + ); + } else { + logger.error( + 'Genkit CLI version is outdated. Please update `genkit-cli` to the latest version.' + ); + process.exit(1); + } + } this.filenameToRuntimeMap[fileName] = runtimeInfo; this.idToFileMap[runtimeInfo.id] = fileName; this.eventEmitter.emit(RuntimeEvent.ADD, runtimeInfo); diff --git a/genkit-tools/common/src/manager/types.ts b/genkit-tools/common/src/manager/types.ts index a50bcc51d..3fb9b0325 100644 --- a/genkit-tools/common/src/manager/types.ts +++ b/genkit-tools/common/src/manager/types.ts @@ -38,6 +38,10 @@ export interface RuntimeInfo { timestamp: string; /** Display name for the project, typically basename of the root folder */ projectName?: string; + /** Genkit runtime library version. Ex: nodejs/0.9.5 or go/0.2.0 */ + genkitVersion?: string; + /** Reflection API specification version. Ex: 1 */ + reflectionApiSpecVersion?: number; } export enum RuntimeEvent { diff --git a/genkit-tools/common/src/server/server.ts b/genkit-tools/common/src/server/server.ts index 73237a859..6a2da6b25 100644 --- a/genkit-tools/common/src/server/server.ts +++ b/genkit-tools/common/src/server/server.ts @@ -64,7 +64,7 @@ export function startServer(manager: RuntimeManager, port: number) { }); app.post('/api/streamAction', bodyParser.json(), async (req, res) => { - const { key, input } = req.body; + const { key, input, context } = req.body; res.writeHead(200, { 'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Headers': 'Content-Type', @@ -72,7 +72,7 @@ export function startServer(manager: RuntimeManager, port: number) { 'Transfer-Encoding': 'chunked', }); - const result = await manager.runAction({ key, input }, (chunk) => { + const result = await manager.runAction({ key, input, context }, (chunk) => { res.write(JSON.stringify(chunk) + '\n'); }); res.write(JSON.stringify(result)); diff --git a/genkit-tools/common/src/types/apis.ts b/genkit-tools/common/src/types/apis.ts index fe0f3128a..bd300cc73 100644 --- a/genkit-tools/common/src/types/apis.ts +++ b/genkit-tools/common/src/types/apis.ts @@ -61,6 +61,14 @@ export const RunActionRequestSchema = z.object({ .any() .optional() .describe('An input with the type that this action expects.'), + context: z + .any() + .optional() + .describe('Additional runtime context data (ex. auth context data).'), + telemetryLabels: z + .record(z.string(), z.string()) + .optional() + .describe('Labels to be applied to telemetry data.'), }); export type RunActionRequest = z.infer; diff --git a/genkit-tools/common/src/types/flow.ts b/genkit-tools/common/src/types/flow.ts deleted file mode 100644 index 2f0a42bd9..000000000 --- a/genkit-tools/common/src/types/flow.ts +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { extendZodWithOpenApi } from '@asteasolutions/zod-to-openapi'; -import * as z from 'zod'; - -extendZodWithOpenApi(z); - -// NOTE: Keep this file in sync with genkit/flow/src/types.ts! -// Eventually tools will be source of truth for these types (by generating a -// JSON schema) but until then this file must be manually kept in sync - -/** - * The message format used by the flow task queue and control interface. - */ -export const FlowInvokeEnvelopeMessageSchema = z.object({ - // Start new flow. - start: z - .object({ - input: z.unknown().optional(), - labels: z.record(z.string(), z.string()).optional(), - }) - .optional(), - // Schedule new flow. - schedule: z - .object({ - input: z.unknown().optional(), - delay: z.number().optional(), - }) - .optional(), - // Run previously scheduled flow. - runScheduled: z - .object({ - flowId: z.string(), - }) - .optional(), - // Retry failed step (only if step is setup for retry) - retry: z - .object({ - flowId: z.string(), - }) - .optional(), - // Resume an interrupted flow. - resume: z - .object({ - flowId: z.string(), - payload: z.unknown().optional(), - }) - .optional(), - // State check for a given flow ID. No side effects, can be used to check flow state. - state: z - .object({ - flowId: z.string(), - }) - .optional(), -}); -export type FlowInvokeEnvelopeMessage = z.infer< - typeof FlowInvokeEnvelopeMessageSchema ->; - -export const FlowActionInputSchema = FlowInvokeEnvelopeMessageSchema.extend({ - auth: z.unknown().optional(), -}); - -export const FlowStateExecutionSchema = z.object({ - startTime: z - .number() - .optional() - .describe('start time in milliseconds since the epoch'), - endTime: z - .number() - .optional() - .describe('end time in milliseconds since the epoch'), - traceIds: z.array(z.string()), -}); -export type FlowStateExecution = z.infer; - -export const FlowResponseSchema = z.object({ - response: z.unknown().nullable(), -}); -export const FlowErrorSchema = z.object({ - error: z.string().optional(), - stacktrace: z.string().optional(), -}); -export type FlowError = z.infer; - -export const FlowResultSchema = FlowResponseSchema.and(FlowErrorSchema); diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index 078d050f2..0ddcdc1ee 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -21,7 +21,6 @@ export * from './apis'; export * from './env'; export * from './eval'; export * from './evaluators'; -export * from './flow'; export * from './model'; export * from './prompt'; export * from './retrievers'; diff --git a/genkit-tools/common/tests/utils/trace.ts b/genkit-tools/common/tests/utils/trace.ts index 1b19c39ab..a22c24c0b 100644 --- a/genkit-tools/common/tests/utils/trace.ts +++ b/genkit-tools/common/tests/utils/trace.ts @@ -235,7 +235,6 @@ export class MockTrace { let baseFlowSpan = { ...this.BASE_FLOW_SPAN }; baseFlowSpan.attributes['genkit:input'] = JSON.stringify(flowInput); baseFlowSpan.attributes['genkit:output'] = JSON.stringify(flowOutput); - baseFlowSpan.attributes['genkit:metadata:flow:state'] = baseFlowState; let wrapperActionSpan = { ...this.WRAPPER_ACTION_SPAN }; wrapperActionSpan.attributes['genkit:input'] = JSON.stringify({ diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index e1115db3f..9d629ed14 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -932,168 +932,6 @@ "toolResponse" ], "additionalProperties": false - }, - "FlowActionInput": { - "type": "object", - "properties": { - "start": { - "type": "object", - "properties": { - "input": {}, - "labels": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - }, - "additionalProperties": false - }, - "schedule": { - "type": "object", - "properties": { - "input": {}, - "delay": { - "type": "number" - } - }, - "additionalProperties": false - }, - "runScheduled": { - "type": "object", - "properties": { - "flowId": { - "type": "string" - } - }, - "required": [ - "flowId" - ], - "additionalProperties": false - }, - "retry": { - "type": "object", - "properties": { - "flowId": { - "type": "string" - } - }, - "required": [ - "flowId" - ], - "additionalProperties": false - }, - "resume": { - "type": "object", - "properties": { - "flowId": { - "type": "string" - }, - "payload": {} - }, - "required": [ - "flowId" - ], - "additionalProperties": false - }, - "state": { - "type": "object", - "properties": { - "flowId": { - "type": "string" - } - }, - "required": [ - "flowId" - ], - "additionalProperties": false - }, - "auth": {} - }, - "additionalProperties": false - }, - "FlowError": { - "type": "object", - "properties": { - "error": { - "type": "string" - }, - "stacktrace": { - "type": "string" - } - }, - "additionalProperties": false - }, - "FlowInvokeEnvelopeMessage": { - "type": "object", - "properties": { - "start": { - "$ref": "#/$defs/FlowActionInput/properties/start" - }, - "schedule": { - "$ref": "#/$defs/FlowActionInput/properties/schedule" - }, - "runScheduled": { - "$ref": "#/$defs/FlowActionInput/properties/runScheduled" - }, - "retry": { - "$ref": "#/$defs/FlowActionInput/properties/retry" - }, - "resume": { - "$ref": "#/$defs/FlowActionInput/properties/resume" - }, - "state": { - "$ref": "#/$defs/FlowActionInput/properties/state" - } - }, - "additionalProperties": false - }, - "FlowResponse": { - "type": "object", - "properties": { - "response": { - "anyOf": [ - {}, - { - "type": "null" - } - ] - } - }, - "additionalProperties": false - }, - "FlowResult": { - "allOf": [ - { - "$ref": "#/$defs/FlowResponse" - }, - { - "$ref": "#/$defs/FlowError" - } - ] - }, - "FlowStateExecution": { - "type": "object", - "properties": { - "startTime": { - "type": "number", - "description": "start time in milliseconds since the epoch" - }, - "endTime": { - "type": "number", - "description": "end time in milliseconds since the epoch" - }, - "traceIds": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "required": [ - "traceIds" - ], - "additionalProperties": false } } } \ No newline at end of file diff --git a/genkit-tools/reflectionApi.yaml b/genkit-tools/reflectionApi.yaml index 0c373830d..683761836 100644 --- a/genkit-tools/reflectionApi.yaml +++ b/genkit-tools/reflectionApi.yaml @@ -332,6 +332,15 @@ paths: input: nullable: true description: An input with the type that this action expects. + context: + nullable: true + description: Additional runtime context data (ex. auth context data). + telemetryLabels: + type: object + nullable: true + additionalProperties: + type: string + description: Labels to be applied to telemetry data. required: - key responses: diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 2ab787225..dd7033f6d 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -25,7 +25,6 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/trace.ts', '../common/src/types/retrievers.ts', '../common/src/types/model.ts', - '../common/src/types/flow.ts', ]; /** Types that may appear that do not need to be included. */ diff --git a/go/core/action.go b/go/core/action.go index 3cbfc253a..e22c6a399 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -264,3 +264,20 @@ func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name s } return a.(*Action[In, Out, Stream]) } + +var actionContextKey = base.NewContextKey[int]() + +// WithActionContext returns a new context with action runtime context (side channel data) +// value set. +func WithActionContext(ctx context.Context, actionContext map[string]any) context.Context { + return context.WithValue(ctx, actionContextKey, actionContext) +} + +// ActionContext returns the action runtime context (side channel data) from ctx. +func ActionContext(ctx context.Context) map[string]any { + val := ctx.Value(actionContextKey) + if val == nil { + return nil + } + return val.(map[string]any) +} diff --git a/go/genkit/conformance_test.go b/go/genkit/conformance_test.go index fc10c02b3..c7c1c3c5a 100644 --- a/go/genkit/conformance_test.go +++ b/go/genkit/conformance_test.go @@ -103,7 +103,7 @@ func TestFlowConformance(t *testing.T) { r.TracingState().WriteTelemetryImmediate(tc) _ = defineFlow(r, test.Name, flowFunction(test.Commands)) key := fmt.Sprintf("/flow/%s", test.Name) - resp, err := runAction(context.Background(), r, key, test.Input, nil) + resp, err := runAction(context.Background(), r, key, test.Input, nil, nil) if err != nil { t.Fatal(err) } diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 7f62ef024..6b0e7d838 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -213,22 +213,26 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. } f.auth = flowOpts.auth metadata := map[string]any{ - "inputSchema": f.inputSchema, - "outputSchema": f.outputSchema, "requiresAuth": f.auth != nil, } - afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { + afunc := func(ctx context.Context, input In, cb func(context.Context, Stream) error) (*Out, error) { tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") - // Only non-durable flows have an auth policy so can safely assume Start.Input. - if inst.Start != nil { - if f.auth != nil { - ctx = f.auth.NewContext(ctx, inst.Auth) - } - if err := f.checkAuthPolicy(ctx, any(inst.Start.Input)); err != nil { + runtimeContext := core.ActionContext(ctx) + if f.auth != nil { + ctx = f.auth.NewContext(ctx, runtimeContext) + if err := f.checkAuthPolicy(ctx, any(input)); err != nil { return nil, err } } - return f.runInstruction(ctx, inst, streamingCallback[Stream](cb)) + var opts []FlowRunOption + if runtimeContext != nil { + opts = append(opts, WithLocalAuth(runtimeContext)) + } + result, err := f.run(ctx, input, streamingCallback[Stream](cb), opts...) + if err != nil { + return nil, err + } + return &result, err } core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc) f.tstate = r.TracingState() @@ -236,54 +240,6 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. return f } -// TODO: use flowError? - -// A flowInstruction is an instruction to follow with a flow. -// It is the input for the flow's action. -// Exactly one field will be non-nil. -type flowInstruction[In any] struct { - Start *startInstruction[In] `json:"start,omitempty"` - Resume *resumeInstruction `json:"resume,omitempty"` - Schedule *scheduleInstruction[In] `json:"schedule,omitempty"` - RunScheduled *runScheduledInstruction `json:"runScheduled,omitempty"` - State *stateInstruction `json:"state,omitempty"` - Retry *retryInstruction `json:"retry,omitempty"` - Auth map[string]any `json:"auth,omitempty"` -} - -// A startInstruction starts a flow. -type startInstruction[In any] struct { - Input In `json:"input,omitempty"` - Labels map[string]string `json:"labels,omitempty"` -} - -// A resumeInstruction resumes a flow that was started and then interrupted. -type resumeInstruction struct { - FlowID string `json:"flowId,omitempty"` - Payload any `json:"payload,omitempty"` -} - -// A scheduleInstruction schedules a flow to start at a later time. -type scheduleInstruction[In any] struct { - DelaySecs float64 `json:"delay,omitempty"` - Input In `json:"input,omitempty"` -} - -// A runScheduledInstruction starts a scheduled flow. -type runScheduledInstruction struct { - FlowID string `json:"flowId,omitempty"` -} - -// A stateInstruction retrieves the flowState from the flow. -type stateInstruction struct { - FlowID string `json:"flowId,omitempty"` -} - -// TODO: document -type retryInstruction struct { - FlowID string `json:"flowId,omitempty"` -} - // A flowState is a persistent representation of a flow that may be in the middle of running. // It contains all the information needed to resume a flow, including the original input // and a cache of all completed steps. @@ -370,30 +326,6 @@ type FlowResult[Out any] struct { StackTrace string `json:"stacktrace,omitempty"` } -// FlowResult is called FlowResponse in the javascript. - -// runInstruction performs one of several actions on a flow, as determined by msg. -// (Called runEnvelope in the js.) -func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb streamingCallback[Stream]) (*flowState[In, Out], error) { - switch { - case inst.Start != nil: - // TODO: pass msg.Start.Labels. - return f.start(ctx, inst.Start.Input, cb) - case inst.Resume != nil: - return nil, errors.ErrUnsupported - case inst.Retry != nil: - return nil, errors.ErrUnsupported - case inst.RunScheduled != nil: - return nil, errors.ErrUnsupported - case inst.Schedule != nil: - return nil, errors.ErrUnsupported - case inst.State != nil: - return nil, errors.ErrUnsupported - default: - return nil, errors.New("all known fields of FlowInvokeEnvelopeMessage are nil") - } -} - // The following methods make Flow[I, O, S] implement the flow interface, define in servers.go. // Name returns the name that the flow was defined with. diff --git a/go/genkit/servers.go b/go/genkit/servers.go index e69dab2b8..ef7cd9455 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -36,8 +36,10 @@ import ( "sync/atomic" "time" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/action" "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" @@ -45,10 +47,12 @@ import ( ) type runtimeFileData struct { - ID string `json:"id"` - PID int `json:"pid"` - ReflectionServerURL string `json:"reflectionServerUrl"` - Timestamp string `json:"timestamp"` + ID string `json:"id"` + PID int `json:"pid"` + ReflectionServerURL string `json:"reflectionServerUrl"` + Timestamp string `json:"timestamp"` + GenkitVersion string `json:"genkitVersion"` + ReflectionApiSpecVersion int `json:"reflectionApiSpecVersion"` } type devServer struct { @@ -94,10 +98,12 @@ func (s *devServer) writeRuntimeFile(url string) error { timestamp := time.Now().UTC().Format(time.RFC3339) s.runtimeFilePath = filepath.Join(runtimesDir, fmt.Sprintf("%d-%s.json", os.Getpid(), timestamp)) data := runtimeFileData{ - ID: runtimeID, - PID: os.Getpid(), - ReflectionServerURL: fmt.Sprintf("http://%s", url), - Timestamp: timestamp, + ID: runtimeID, + PID: os.Getpid(), + ReflectionServerURL: fmt.Sprintf("http://%s", url), + Timestamp: timestamp, + GenkitVersion: "go/" + internal.Version, + ReflectionApiSpecVersion: internal.GENKIT_REFLECTION_API_SPEC_VERSION, } fileContent, err := json.MarshalIndent(data, "", " ") if err != nil { @@ -241,8 +247,9 @@ func newDevServeMux(s *devServer) *http.ServeMux { func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var body struct { - Key string `json:"key"` - Input json.RawMessage `json:"input"` + Key string `json:"key"` + Input json.RawMessage `json:"input"` + Context json.RawMessage `json:"context"` } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { @@ -271,7 +278,11 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro return nil } } - resp, err := runAction(ctx, s.reg, body.Key, body.Input, callback) + var contextMap map[string]any = nil + if body.Context != nil { + json.Unmarshal(body.Context, &contextMap) + } + resp, err := runAction(ctx, s.reg, body.Key, body.Input, callback, contextMap) if err != nil { return err } @@ -281,7 +292,8 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro // handleNotify configures the telemetry server URL from the request. func (s *devServer) handleNotify(w http.ResponseWriter, r *http.Request) error { var body struct { - TelemetryServerURL string `json:"telemetryServerUrl"` + TelemetryServerURL string `json:"telemetryServerUrl"` + ReflectionApiSpecVersion int `json:"reflectionApiSpecVersion"` } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { @@ -291,6 +303,9 @@ func (s *devServer) handleNotify(w http.ResponseWriter, r *http.Request) error { s.reg.TracingState().WriteTelemetryImmediate(tracing.NewHTTPTelemetryClient(body.TelemetryServerURL)) slog.Debug("connected to telemetry server", "url", body.TelemetryServerURL) } + if body.ReflectionApiSpecVersion != internal.GENKIT_REFLECTION_API_SPEC_VERSION { + slog.Error("Genkit CLI version is not compatible with runtime library. Please use `genkit-cli` version compatible with runtime library version.") + } w.WriteHeader(http.StatusOK) _, err := w.Write([]byte("OK")) return err @@ -305,11 +320,15 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (*runActionResponse, error) { +func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := reg.LookupAction(key) if action == nil { return nil, &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no action with key %q", key)} } + if runtimeContext != nil { + ctx = core.WithActionContext(ctx, runtimeContext) + } + var traceID string output, err := tracing.RunInNewSpan(ctx, reg.TracingState(), "dev-run-action-wrapper", "", true, input, func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { tracing.SetCustomMetadataAttr(ctx, "genkit-dev-internal", "true") diff --git a/go/genkit/testdata/conformance/basic.json b/go/genkit/testdata/conformance/basic.json index 67a39b2c1..3428e8fce 100644 --- a/go/genkit/testdata/conformance/basic.json +++ b/go/genkit/testdata/conformance/basic.json @@ -5,13 +5,8 @@ {"run": {"name": "call-llm", "command": {"append": "y"}}}, {"run": {"name": "call-llm", "command": {"append": "z"}}} ], - "input": {"start": {"input": "x"}}, - "result": { - "operation": { - "done": true, - "result": {"response": "xyz"} - } - }, + "input": "x", + "result": "xyz", "trace": { "displayName": "dev-run-action-wrapper", "spans": { @@ -110,7 +105,7 @@ "genkit:type": "action", "genkit:name": "basic", "genkit:path": "/dev-run-action-wrapper/basic", - "genkit:input": "{\"start\":{\"input\":\"x\"}}", + "genkit:input": "\"x\"", "genkit:metadata:flow:wrapperAction": "true", "genkit:output": "$ANYTHING", "genkit:state": "success" diff --git a/go/genkit/testdata/conformance/run-1.json b/go/genkit/testdata/conformance/run-1.json index 5e904bd0b..a2c779574 100644 --- a/go/genkit/testdata/conformance/run-1.json +++ b/go/genkit/testdata/conformance/run-1.json @@ -1,13 +1,8 @@ { "name": "run", "commands": [{"run" :{"name": "r", "command": {"append": "x"}}}], - "input": {"start": {"input": ""}}, - "result": { - "operation": { - "done": true, - "result": {"response": "x"} - } - }, + "input": "", + "result": "x", "trace": { "displayName": "dev-run-action-wrapper", "spans": { @@ -73,7 +68,7 @@ "genkit:type": "action", "genkit:name": "run", "genkit:path": "/dev-run-action-wrapper/run", - "genkit:input": "{\"start\":{}}", + "genkit:input": "\"\"", "genkit:metadata:flow:wrapperAction": "true", "genkit:output": "$ANYTHING", "_comment": "The output above is a JSON object with various random IDs", diff --git a/go/internal/version.go b/go/internal/version.go index 6b9f6247f..d1cc79ca5 100644 --- a/go/internal/version.go +++ b/go/internal/version.go @@ -17,3 +17,5 @@ package internal // Version is the current tagged release of this module. // That is, it should match the value of the latest `go/v*` git tag. const Version = "0.2.0" + +const GENKIT_REFLECTION_API_SPEC_VERSION = 1 diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index 2c33237a6..68f119035 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -34,8 +34,10 @@ package main import ( "context" + "encoding/json" "errors" "fmt" + "log" "strconv" @@ -51,6 +53,17 @@ func main() { return genkit.Run(ctx, "call-llm", func() (string, error) { return "foo: " + foo, nil }) }) + auth := &testAuth{} + + genkit.DefineFlow("withContext", func(ctx context.Context, subject string) (string, error) { + authJson, err := json.Marshal(auth.FromContext(ctx)) + if err != nil { + return "", err + } + + return "subject=" + subject + ",auth=" + string(authJson), nil + }, genkit.WithFlowAuth(auth)) + genkit.DefineFlow("parent", func(ctx context.Context, _ struct{}) (string, error) { return basic.Run(ctx, "foo") }) @@ -92,3 +105,47 @@ func main() { log.Fatal(err) } } + +type testAuth struct { + genkit.FlowAuth +} + +const authKey = "testAuth" + +// ProvideAuthContext provides auth context from an auth header and sets it on the context. +func (f *testAuth) ProvideAuthContext(ctx context.Context, authHeader string) (context.Context, error) { + var context genkit.AuthContext + context = map[string]any{ + "username": authHeader, + } + return f.NewContext(ctx, context), nil +} + +// NewContext sets the auth context on the given context. +func (f *testAuth) NewContext(ctx context.Context, authContext genkit.AuthContext) context.Context { + return context.WithValue(ctx, authKey, authContext) +} + +// FromContext retrieves the auth context from the given context. +func (*testAuth) FromContext(ctx context.Context) genkit.AuthContext { + if ctx == nil { + return nil + } + val := ctx.Value(authKey) + if val == nil { + return nil + } + return val.(genkit.AuthContext) +} + +// CheckAuthPolicy checks auth context against policy. +func (f *testAuth) CheckAuthPolicy(ctx context.Context, input any) error { + authContext := f.FromContext(ctx) + if authContext == nil { + return errors.New("auth is required") + } + if authContext["username"] != "authorized" { + return errors.New("unauthorized") + } + return nil +} diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 92595be88..fe18083a5 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -71,6 +71,11 @@ export interface ActionRunOptions { * Additional runtime context data (ex. auth context data). */ context?: any; + + /** + * Additional span attributes to apply to OT spans. + */ + telemetryLabels?: Record; } /** @@ -127,7 +132,8 @@ type ActionParams< outputJsonSchema?: JSONSchema7; metadata?: Record; use?: Middleware, z.infer, z.infer>[]; - streamingSchema?: S; + streamSchema?: S; + actionType: ActionType; }; export type SimpleMiddleware = ( @@ -248,9 +254,18 @@ export function action< name: actionName, labels: { [SPAN_TYPE_ATTR]: 'action', + 'genkit:metadata:subtype': config.actionType, + ...options?.telemetryLabels, }, }, async (metadata, span) => { + setCustomMetadataAttributes({ subtype: config.actionType }); + if (options?.context) { + setCustomMetadataAttributes({ + context: JSON.stringify(options.context), + }); + } + traceId = span.spanContext().traceId; spanId = span.spanContext().spanId; metadata.name = actionName; @@ -317,14 +332,12 @@ export function defineAction< S extends z.ZodTypeAny = z.ZodTypeAny, >( registry: Registry, - config: ActionParams & { - actionType: ActionType; - }, + config: ActionParams, fn: ( input: z.infer, options: ActionFnArg> ) => Promise> -): Action { +): Action { if (isInRuntimeContext()) { throw new Error( 'Cannot define new actions at runtime.\n' + @@ -337,7 +350,6 @@ export function defineAction< validateActionId(config.name.actionId); } const act = action(config, async (i: I, options): Promise> => { - setCustomMetadataAttributes({ subtype: config.actionType }); await registry.initializeAllPlugins(); return await runInActionRuntimeContext(() => fn(i, options)); }); diff --git a/js/core/src/auth.ts b/js/core/src/auth.ts index 7e05df0c5..753be5153 100644 --- a/js/core/src/auth.ts +++ b/js/core/src/auth.ts @@ -17,19 +17,30 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { runInActionRuntimeContext } from './action.js'; -const authAsyncLocalStorage = new AsyncLocalStorage(); +const contextAsyncLocalStorage = new AsyncLocalStorage(); /** - * Execute the provided function in the auth context. Call {@link getFlowAuth()} anywhere - * within the async call stack to retrieve the auth. + * Execute the provided function in the runtime context. Call {@link getFlowContext()} anywhere + * within the async call stack to retrieve the context. */ -export function runWithAuthContext(auth: any, fn: () => R) { - return authAsyncLocalStorage.run(auth, () => runInActionRuntimeContext(fn)); +export function runWithContext(context: any, fn: () => R) { + return contextAsyncLocalStorage.run(context, () => + runInActionRuntimeContext(fn) + ); } /** * Gets the auth object from the current context. + * + * @deprecated use {@link getFlowContext} */ export function getFlowAuth(): any { - return authAsyncLocalStorage.getStore(); + return contextAsyncLocalStorage.getStore(); +} + +/** + * Gets the runtime context of the current flow. + */ +export function getFlowContext(): any { + return contextAsyncLocalStorage.getStore(); } diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 60d193146..2dd53d802 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { SpanStatusCode } from '@opentelemetry/api'; import * as bodyParser from 'body-parser'; import cors, { CorsOptions } from 'cors'; import express from 'express'; @@ -22,24 +21,15 @@ import { Server } from 'http'; import { z } from 'zod'; import { Action, + ActionResult, defineAction, - getStreamingCallback, StreamingCallback, } from './action.js'; -import { runWithAuthContext } from './auth.js'; +import { runWithContext } from './auth.js'; import { getErrorMessage, getErrorStack } from './error.js'; -import { FlowActionInputSchema } from './flowTypes.js'; import { logger } from './logging.js'; import { Registry } from './registry.js'; -import { toJsonSchema } from './schema.js'; -import { - newTrace, - runInNewSpan, - setCustomMetadataAttribute, - setCustomMetadataAttributes, - SPAN_TYPE_ATTR, -} from './tracing.js'; -import { flowMetadataPrefix } from './utils.js'; +import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js'; const streamDelimiter = '\n\n'; @@ -109,7 +99,7 @@ export interface CallableFlow< stream(input?: z.infer, opts?: FlowCallOptions): StreamingResponse; - flow: Flow; + flow: Flow; } /** @@ -152,18 +142,6 @@ export type FlowFn< streamingCallback: StreamingCallback> ) => Promise> | z.infer; -/** - * Represents the result of a flow execution. - */ -interface FlowResult { - /** The result of the flow execution. */ - result: O; - /** The trace ID associated with the flow execution. */ - traceId: string; - /** The root span ID of the associated trace. */ - spanId: string; -} - export class Flow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, @@ -175,12 +153,12 @@ export class Flow< readonly streamSchema?: S; readonly authPolicy?: FlowAuthPolicy; readonly middleware?: express.RequestHandler[]; - readonly flowFn: FlowFn; + readonly action: Action; constructor( private registry: Registry, config: FlowConfig | StreamingFlowConfig, - flowFn: FlowFn + action: Action ) { this.name = config.name; this.inputSchema = config.inputSchema; @@ -189,7 +167,7 @@ export class Flow< 'streamSchema' in config ? config.streamSchema : undefined; this.authPolicy = config.authPolicy; this.middleware = config.middleware; - this.flowFn = flowFn; + this.action = action; } /** @@ -200,61 +178,15 @@ export class Flow< opts: { streamingCallback?: StreamingCallback>; labels?: Record; - auth?: unknown; + context?: unknown; } - ): Promise>> { + ): Promise>> { await this.registry.initializeAllPlugins(); - return await runWithAuthContext(opts.auth, () => - newTrace( - { - name: this.name, - labels: { - [SPAN_TYPE_ATTR]: 'flow', - }, - }, - async (metadata, rootSpan) => { - if (opts.labels) { - const labels = opts.labels; - Object.keys(opts.labels).forEach((label) => { - setCustomMetadataAttribute( - flowMetadataPrefix(`label:${label}`), - labels[label] - ); - }); - } - - setCustomMetadataAttributes({ - [flowMetadataPrefix('name')]: this.name, - }); - try { - metadata.input = input; - const output = await this.flowFn( - input, - opts.streamingCallback ?? (() => {}) - ); - metadata.output = JSON.stringify(output); - setCustomMetadataAttribute(flowMetadataPrefix('state'), 'done'); - return { - result: output, - traceId: rootSpan.spanContext().traceId, - spanId: rootSpan.spanContext().spanId, - }; - } catch (e) { - metadata.state = 'error'; - rootSpan.setStatus({ - code: SpanStatusCode.ERROR, - message: getErrorMessage(e), - }); - if (e instanceof Error) { - rootSpan.recordException(e); - } - - setCustomMetadataAttribute(flowMetadataPrefix('state'), 'error'); - throw e; - } - } - ) - ); + return await this.action.run(input, { + context: opts.context, + telemetryLabels: opts.labels, + onChunk: opts.streamingCallback ?? (() => {}), + }); } /** @@ -271,7 +203,7 @@ export class Flow< } const result = await this.invoke(input, { - auth: opts?.context || opts?.withLocalAuthContext, + context: opts?.context || opts?.withLocalAuthContext, }); return result.result; } @@ -306,7 +238,7 @@ export class Flow< }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - auth: opts?.context || opts?.withLocalAuthContext, + context: opts?.context || opts?.withLocalAuthContext, } ).then((s) => s.result) ) @@ -366,7 +298,7 @@ export class Flow< 'data: ' + JSON.stringify({ message: chunk }) + streamDelimiter ); }, - auth, + context: auth, }); response.write( 'data: ' + JSON.stringify({ result: result.result }) + streamDelimiter @@ -389,9 +321,9 @@ export class Flow< } } else { try { - const result = await this.invoke(input, { auth }); - response.setHeader('x-genkit-trace-id', result.traceId); - response.setHeader('x-genkit-span-id', result.spanId); + const result = await this.invoke(input, { context: auth }); + response.setHeader('x-genkit-trace-id', result.telemetry.traceId); + response.setHeader('x-genkit-span-id', result.telemetry.spanId); // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." response .status(200) @@ -538,18 +470,15 @@ export function defineFlow< S extends z.ZodTypeAny = z.ZodTypeAny, >( registry: Registry, - config: StreamingFlowConfig | string, + config: StreamingFlowConfig | string, fn: FlowFn ): CallableFlow { - const resolvedConfig: FlowConfig = + const resolvedConfig: StreamingFlowConfig = typeof config === 'string' ? { name: config } : config; - const flow = new Flow(registry, resolvedConfig, fn); - registerFlowAction(registry, flow); - const callableFlow = async ( - input: z.infer, - opts: FlowCallOptions - ): Promise> => { + const flowAction = defineFlowAction(registry, resolvedConfig, fn); + const flow = new Flow(registry, resolvedConfig, flowAction); + const callableFlow = async (input, opts) => { return flow.run(input, opts); }; (callableFlow as CallableFlow).flow = flow; @@ -574,8 +503,8 @@ export function defineStreamingFlow< config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = new Flow(registry, config, fn); - registerFlowAction(registry, flow); + const flowAction = defineFlowAction(registry, config, fn); + const flow = new Flow(registry, config, flowAction); const streamableFlow: StreamableFlow = (input, opts) => { return flow.stream(input, opts); }; @@ -586,41 +515,30 @@ export function defineStreamingFlow< /** * Registers a flow as an action in the registry. */ -function registerFlowAction< +function defineFlowAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( registry: Registry, - flow: Flow -): Action { + config: StreamingFlowConfig, + fn: FlowFn +): Action { return defineAction( registry, { actionType: 'flow', - name: flow.name, - inputSchema: FlowActionInputSchema, - outputSchema: flow.outputSchema, + name: config.name, + inputSchema: config.inputSchema, + outputSchema: config.outputSchema, + streamSchema: config.streamSchema, metadata: { - inputSchema: toJsonSchema({ schema: flow.inputSchema }), - outputSchema: toJsonSchema({ schema: flow.outputSchema }), - requiresAuth: !!flow.authPolicy, + requiresAuth: !!config.authPolicy, }, }, - async (envelope) => { - await flow.authPolicy?.( - envelope.auth, - envelope.start?.input as I | undefined - ); - setCustomMetadataAttribute(flowMetadataPrefix('wrapperAction'), 'true'); - const response = await flow.invoke(envelope.start?.input, { - streamingCallback: getStreamingCallback() as S extends z.ZodVoid - ? undefined - : StreamingCallback>, - auth: envelope.auth, - labels: envelope.start?.labels, - }); - return response.result; + async (input, { sendChunk, context }) => { + await config.authPolicy?.(context, input); + return await runWithContext(context, () => fn(input, sendChunk)); } ); } diff --git a/js/core/src/flowTypes.ts b/js/core/src/flowTypes.ts deleted file mode 100644 index bb56bf530..000000000 --- a/js/core/src/flowTypes.ts +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { z } from 'zod'; - -// NOTE: Keep this file in sync with genkit-tools/src/types/flow.ts! -// Eventually tools will be source of truth for these types (by generating a -// JSON schema) but until then this file must be manually kept in sync - -export const FlowResponseSchema = z.object({ - response: z.unknown().nullable(), -}); - -export const FlowErrorSchema = z.object({ - error: z.string().optional(), - stacktrace: z.string().optional(), -}); - -export type FlowError = z.infer; - -export const FlowResultSchema = FlowResponseSchema.and(FlowErrorSchema); - -/** - * Used for flow control. - */ -export const FlowInvokeEnvelopeMessageSchema = z.object({ - // Start new flow. - start: z.object({ - input: z.unknown().optional(), - labels: z.record(z.string(), z.string()).optional(), - }), -}); - -export type FlowInvokeEnvelopeMessage = z.infer< - typeof FlowInvokeEnvelopeMessageSchema ->; - -/** - * Used by the flow action. - */ -export const FlowActionInputSchema = FlowInvokeEnvelopeMessageSchema.extend({ - auth: z.unknown().optional(), -}); - -export type FlowActionInput = z.infer; diff --git a/js/core/src/index.ts b/js/core/src/index.ts index 3bd1c1ad3..f8d4ce85a 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -18,6 +18,7 @@ import { version } from './__codegen/version.js'; export const GENKIT_VERSION = version; export const GENKIT_CLIENT_HEADER = `genkit-node/${GENKIT_VERSION} gl-node/${process.versions.node}`; +export const GENKIT_REFLECTION_API_SPEC_VERSION = 1; export { z } from 'zod'; export * from './action.js'; @@ -38,7 +39,6 @@ export { type StreamingFlowConfig, type __RequestWithAuth, } from './flow.js'; -export * from './flowTypes.js'; export * from './plugin.js'; export * from './reflection.js'; export { defineJsonSchema, defineSchema, type JSONSchema } from './schema.js'; diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 4ebf30a69..dae0054e1 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -21,7 +21,7 @@ import { Server } from 'http'; import path from 'path'; import z from 'zod'; import { Status, StatusCodes, runWithStreamingCallback } from './action.js'; -import { GENKIT_VERSION } from './index.js'; +import { GENKIT_REFLECTION_API_SPEC_VERSION, GENKIT_VERSION } from './index.js'; import { logger } from './logging.js'; import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; @@ -153,7 +153,7 @@ export class ReflectionServer { }); server.post('/api/runAction', async (request, response, next) => { - const { key, input } = request.body; + const { key, input, context, telemetryLabels } = request.body; const { stream } = request.query; logger.debug(`Running action \`${key}\` with stream=${stream}...`); let traceId; @@ -164,11 +164,12 @@ export class ReflectionServer { return; } if (stream === 'true') { + const callback = (chunk) => { + response.write(JSON.stringify(chunk) + '\n'); + }; const result = await runWithStreamingCallback( - (chunk) => { - response.write(JSON.stringify(chunk) + '\n'); - }, - async () => await action.run(input) + callback, + async () => await action.run(input, { context, onChunk: callback }) ); await flushTracing(); response.write( @@ -181,7 +182,7 @@ export class ReflectionServer { ); response.end(); } else { - const result = await action.run(input); + const result = await action.run(input, { context, telemetryLabels }); await flushTracing(); response.send({ result: result.result, @@ -201,11 +202,27 @@ export class ReflectionServer { }); server.post('/api/notify', async (request, response) => { - const { telemetryServerUrl } = request.body; + const { telemetryServerUrl, reflectionApiSpecVersion } = request.body; if (typeof telemetryServerUrl === 'string') { setTelemetryServerUrl(telemetryServerUrl); logger.debug(`Connected to telemetry server on ${telemetryServerUrl}`); } + if (reflectionApiSpecVersion !== GENKIT_REFLECTION_API_SPEC_VERSION) { + if ( + !reflectionApiSpecVersion || + reflectionApiSpecVersion < GENKIT_REFLECTION_API_SPEC_VERSION + ) { + logger.warn( + 'WARNING: Genkit CLI version may be outdated. Please update `genkit-cli` to the latest version.' + ); + } else { + logger.warn( + 'Genkit CLI is newer than runtime library. Some feature may not be supported. ' + + 'Consider upgrading your runtime library version (debug info: expected ' + + `${GENKIT_REFLECTION_API_SPEC_VERSION}, got ${reflectionApiSpecVersion}).` + ); + } + } response.status(200).send('OK'); }); @@ -286,6 +303,8 @@ export class ReflectionServer { pid: process.pid, reflectionServerUrl: `http://localhost:${this.port}`, timestamp, + genkitVersion: `nodejs/${GENKIT_VERSION}`, + reflectionApiSpecVersion: GENKIT_REFLECTION_API_SPEC_VERSION, }, null, 2 diff --git a/js/core/src/tracing/instrumentation.ts b/js/core/src/tracing/instrumentation.ts index 2203ef1f2..0dd0173cb 100644 --- a/js/core/src/tracing/instrumentation.ts +++ b/js/core/src/tracing/instrumentation.ts @@ -230,7 +230,9 @@ function buildPath( labels?: Record ) { const stepType = - labels && labels['genkit:type'] ? `,t:${labels['genkit:type']}` : ''; + labels && labels['genkit:type'] + ? `,t:${labels['genkit:metadata:subtype'] === 'flow' ? 'flow' : labels['genkit:type']}` + : ''; return parentPath + `/{${name}${stepType}}`; } diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 5a3df6e7f..cecbab523 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -17,6 +17,7 @@ import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; +import { getFlowContext } from '../src/auth.js'; import { defineFlow, defineStreamingFlow, run } from '../src/flow.js'; import { defineAction, getFlowAuth, z } from '../src/index.js'; import { Registry } from '../src/registry.js'; @@ -42,40 +43,6 @@ function createTestFlow(registry: Registry) { ); } -function createTestAuthFlow(registry: Registry) { - return defineFlow( - registry, - { - name: 'testFlow', - inputSchema: z.string(), - outputSchema: z.string(), - }, - async (input) => { - return `bar ${input} ${JSON.stringify(getFlowAuth())}`; - } - ); -} - -function createTestAuthStreamingFlow(registry: Registry) { - return defineStreamingFlow( - registry, - { - name: 'testFlow', - inputSchema: z.number(), - outputSchema: z.string(), - streamSchema: z.object({ count: z.number() }), - }, - async (input, streamingCallback) => { - if (streamingCallback) { - for (let i = 0; i < input; i++) { - streamingCallback({ count: i }); - } - } - return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getFlowAuth())}`; - } - ); -} - function createTestStreamingFlow(registry: Registry) { return defineStreamingFlow( registry, @@ -207,7 +174,17 @@ describe('flow', () => { describe('getFlowAuth', () => { it('should run the flow', async () => { - const testFlow = createTestAuthFlow(registry); + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + return `bar ${input} ${JSON.stringify(getFlowAuth())}`; + } + ); const response = await testFlow('foo', { withLocalAuthContext: { user: 'test-user' }, @@ -217,7 +194,23 @@ describe('flow', () => { }); it('should streams the flow', async () => { - const testFlow = createTestAuthStreamingFlow(registry); + const testFlow = defineStreamingFlow( + registry, + { + name: 'testFlow', + inputSchema: z.number(), + outputSchema: z.string(), + streamSchema: z.object({ count: z.number() }), + }, + async (input, streamingCallback) => { + if (streamingCallback) { + for (let i = 0; i < input; i++) { + streamingCallback({ count: i }); + } + } + return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getFlowAuth())}`; + } + ); const response = testFlow(3, { withLocalAuthContext: { user: 'test-user' }, @@ -233,6 +226,60 @@ describe('flow', () => { }); }); + describe('getFlowContext', () => { + it('should run the flow', async () => { + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + return `bar ${input} ${JSON.stringify(getFlowContext())}`; + } + ); + + const response = await testFlow('foo', { + context: { user: 'test-user' }, + }); + + assert.equal(response, 'bar foo {"user":"test-user"}'); + }); + + it('should streams the flow', async () => { + const testFlow = defineStreamingFlow( + registry, + { + name: 'testFlow', + inputSchema: z.number(), + outputSchema: z.string(), + streamSchema: z.object({ count: z.number() }), + }, + async (input, streamingCallback) => { + if (streamingCallback) { + for (let i = 0; i < input; i++) { + streamingCallback({ count: i }); + } + } + return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getFlowContext())}`; + } + ); + + const response = testFlow(3, { + context: { user: 'test-user' }, + }); + + const gotChunks: any[] = []; + for await (const chunk of response.stream) { + gotChunks.push(chunk); + } + + assert.equal(await response.output, 'bar 3 true {"user":"test-user"}'); + assert.deepEqual(gotChunks, [{ count: 0 }, { count: 1 }, { count: 2 }]); + }); + }); + describe('telemetry', async () => { beforeEach(() => { spanExporter.exportedSpans = []; @@ -249,13 +296,12 @@ describe('flow', () => { assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, { 'genkit:input': '"foo"', 'genkit:isRoot': true, - 'genkit:metadata:flow:name': 'testFlow', - 'genkit:metadata:flow:state': 'done', + 'genkit:metadata:subtype': 'flow', 'genkit:name': 'testFlow', 'genkit:output': '"bar foo"', 'genkit:path': '/{testFlow,t:flow}', 'genkit:state': 'success', - 'genkit:type': 'flow', + 'genkit:type': 'action', }); }); @@ -285,7 +331,7 @@ describe('flow', () => { }); } ); - const result = await testFlow('foo'); + const result = await testFlow('foo', { context: { user: 'pavel' } }); assert.equal(result, 'foo bar'); assert.strictEqual(spanExporter.exportedSpans.length, 3); @@ -317,13 +363,13 @@ describe('flow', () => { assert.deepStrictEqual(spanExporter.exportedSpans[2].attributes, { 'genkit:input': '"foo"', 'genkit:isRoot': true, - 'genkit:metadata:flow:name': 'testFlow', - 'genkit:metadata:flow:state': 'done', + 'genkit:metadata:subtype': 'flow', + 'genkit:metadata:context': '{"user":"pavel"}', 'genkit:name': 'testFlow', 'genkit:output': '"foo bar"', 'genkit:path': '/{testFlow,t:flow}', 'genkit:state': 'success', - 'genkit:type': 'flow', + 'genkit:type': 'action', }); }); }); diff --git a/js/genkit/src/index.ts b/js/genkit/src/index.ts index bc7dd3e0d..33b7eb36a 100644 --- a/js/genkit/src/index.ts +++ b/js/genkit/src/index.ts @@ -101,9 +101,6 @@ export { } from '@genkit-ai/ai'; export { type SessionData, type SessionStore } from '@genkit-ai/ai/session'; export { - FlowActionInputSchema, - FlowErrorSchema, - FlowInvokeEnvelopeMessageSchema, FlowServer, GENKIT_CLIENT_HEADER, GENKIT_VERSION, @@ -129,14 +126,9 @@ export { type ActionMetadata, type CallableFlow, type Flow, - type FlowActionInput, type FlowAuthPolicy, type FlowConfig, - type FlowError, type FlowFn, - type FlowInvokeEnvelopeMessage, - type FlowResponseSchema, - type FlowResultSchema, type FlowServerOptions, type JSONSchema, type JSONSchema7, diff --git a/js/plugins/google-cloud/tests/logs_no_io_test.ts b/js/plugins/google-cloud/tests/logs_no_io_test.ts index a7111a610..8025580c1 100644 --- a/js/plugins/google-cloud/tests/logs_no_io_test.ts +++ b/js/plugins/google-cloud/tests/logs_no_io_test.ts @@ -219,7 +219,7 @@ function createFlowWithInput( { name, inputSchema: z.string(), - outputSchema: z.string(), + outputSchema: z.any(), }, fn ); diff --git a/js/plugins/google-cloud/tests/logs_test.ts b/js/plugins/google-cloud/tests/logs_test.ts index ae4e45d04..01ffa9c0b 100644 --- a/js/plugins/google-cloud/tests/logs_test.ts +++ b/js/plugins/google-cloud/tests/logs_test.ts @@ -272,7 +272,7 @@ function createFlowWithInput( { name, inputSchema: z.string(), - outputSchema: z.string(), + outputSchema: z.any(), }, fn ); diff --git a/js/plugins/google-cloud/tests/traces_test.ts b/js/plugins/google-cloud/tests/traces_test.ts index ffa401c16..fefaa4c37 100644 --- a/js/plugins/google-cloud/tests/traces_test.ts +++ b/js/plugins/google-cloud/tests/traces_test.ts @@ -91,7 +91,8 @@ describe('GoogleCloudTracing', () => { const spans = await getExportedSpans(); // Check some common attributes assert.equal(spans[0].attributes['genkit/name'], 'testFlow'); - assert.equal(spans[0].attributes['genkit/type'], 'flow'); + assert.equal(spans[0].attributes['genkit/type'], 'action'); + assert.equal(spans[0].attributes['genkit/metadata/subtype'], 'flow'); // Ensure we have no attributes with ':' because these are awkward to use in // Cloud Trace. const spanAttrKeys = Object.entries(spans[0].attributes).map(([k, v]) => k); diff --git a/js/testapps/flow-sample1/src/index.ts b/js/testapps/flow-sample1/src/index.ts index 045bb224b..fcaeb84a4 100644 --- a/js/testapps/flow-sample1/src/index.ts +++ b/js/testapps/flow-sample1/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { genkit, run, z } from 'genkit'; +import { genkit, getFlowAuth, run, z } from 'genkit'; const ai = genkit({}); @@ -22,7 +22,7 @@ const ai = genkit({}); * To run this flow; * genkit flow:run basic "\"hello\"" */ -export const basic = ai.defineFlow({ name: 'basic' }, async (subject) => { +export const basic = ai.defineFlow('basic', async (subject) => { const foo = await run('call-llm', async () => { return `subject: ${subject}`; }); @@ -39,6 +39,30 @@ export const parent = ai.defineFlow( } ); +export const withInputSchema = ai.defineFlow( + { name: 'withInputSchema', inputSchema: z.object({ subject: z.string() }) }, + async (input) => { + const foo = await run('call-llm', async () => { + return `subject: ${input.subject}`; + }); + + return await run('call-llm1', async () => { + return `foo: ${foo}`; + }); + } +); + +export const withContext = ai.defineFlow( + { + name: 'withContext', + inputSchema: z.object({ subject: z.string() }), + authPolicy: () => {}, + }, + async (input) => { + return `subject: ${input.subject}, context: ${JSON.stringify(getFlowAuth())}`; + } +); + // genkit flow:run streamy 5 -s export const streamy = ai.defineStreamingFlow( { diff --git a/tests/reflection_api_tests.yaml b/tests/reflection_api_tests.yaml index 0f68a54b1..060ce3302 100644 --- a/tests/reflection_api_tests.yaml +++ b/tests/reflection_api_tests.yaml @@ -876,8 +876,3 @@ tests: /flow/testFlow: key: /flow/testFlow name: testFlow - metadata: - inputSchema: - type: string - outputSchema: - type: string