Skip to content

Commit

Permalink
Merge pull request #27 from parea-ai/PAI-728-image-capabilities-in-tr…
Browse files Browse the repository at this point in the history
…ace-sdkt

feat(images): handle images
  • Loading branch information
jalexanderII authored Mar 6, 2024
2 parents d11b903 + 5a15607 commit b1b8ca6
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 40 deletions.
68 changes: 40 additions & 28 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {

import { HTTPClient } from './api-client';
import { pareaLogger } from './parea_logger';
import { genTraceId } from './helpers';
import { genTraceId, serializeMetadataValues } from './helpers';
import { asyncLocalStorage } from './utils/trace_utils';
import { pareaProject } from './project';
import { Experiment } from './experiment/experiment';
Expand Down Expand Up @@ -49,39 +49,14 @@ export class Parea {
}

public async completion(data: Completion): Promise<CompletionResponse> {
let experiment_uuid;
const parentStore = asyncLocalStorage.getStore();
const parentTraceId = parentStore ? Array.from(parentStore.keys())[0] : undefined; // Assuming the last traceId is the parent

const inference_id = genTraceId();
data.inference_id = inference_id;
data.parent_trace_id = parentTraceId || inference_id;
data.root_trace_id = parentStore ? Array.from(parentStore.values())[0].rootTraceId : data.parent_trace_id;

if (process.env.PAREA_OS_ENV_EXPERIMENT_UUID) {
experiment_uuid = process.env.PAREA_OS_ENV_EXPERIMENT_UUID;
data.experiment_uuid = experiment_uuid;
}
const requestData = await this.updateDataAndTrace(data);

const response = await this.client.request({
method: 'POST',
endpoint: COMPLETION_ENDPOINT,
data: {
project_uuid: await pareaProject.getProjectUUID(),
...data,
},
data: requestData,
});

if (parentStore && parentTraceId) {
const parentTraceLog = parentStore.get(parentTraceId);
if (parentTraceLog) {
parentTraceLog.traceLog.children.push(inference_id);
parentTraceLog.traceLog.experiment_uuid = experiment_uuid;
parentStore.set(parentTraceId, parentTraceLog);
await pareaLogger.recordLog(parentTraceLog.traceLog);
}
}

return response.data;
}

Expand Down Expand Up @@ -171,4 +146,41 @@ export class Parea {
}
return new Experiment(data, func, '', this, options?.metadata, options?.datasetLevelEvalFuncs, options?.nWorkers);
}

private async updateDataAndTrace(data: Completion): Promise<Completion> {
// @ts-ignore
data = serializeMetadataValues(data);

let experiment_uuid;
const inference_id = genTraceId();
data.inference_id = inference_id;
data.project_uuid = await pareaProject.getProjectUUID();

try {
const parentStore = asyncLocalStorage.getStore();
const parentTraceId = parentStore ? Array.from(parentStore.keys())[0] : undefined; // Assuming the last traceId is the parent

data.parent_trace_id = parentTraceId || inference_id;
data.root_trace_id = parentStore ? Array.from(parentStore.values())[0].rootTraceId : data.parent_trace_id;

if (process.env.PAREA_OS_ENV_EXPERIMENT_UUID) {
experiment_uuid = process.env.PAREA_OS_ENV_EXPERIMENT_UUID;
data.experiment_uuid = experiment_uuid;
}

if (parentStore && parentTraceId) {
const parentTraceLog = parentStore.get(parentTraceId);
if (parentTraceLog) {
parentTraceLog.traceLog.children.push(inference_id);
parentTraceLog.traceLog.experiment_uuid = experiment_uuid;
parentStore.set(parentTraceId, parentTraceLog);
await pareaLogger.recordLog(parentTraceLog.traceLog);
}
}
} catch (e) {
console.debug(`Error updating trace ids for completion. Trace log will be absent: ${e}`);
}

return data;
}
}
8 changes: 4 additions & 4 deletions src/cookbook/tracing_with_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ dotenv.config();
const p = new Parea(process.env.PAREA_API_KEY);

const LLM_OPTIONS = [
['gpt-3.5-turbo', 'openai'],
['gpt-4', 'openai'],
['claude-instant-1', 'anthropic'],
['claude-2', 'anthropic'],
['gpt-3.5-turbo-0125', 'openai'],
['gpt-4-0125-preview', 'openai'],
['claude-3-sonnet-20240229', 'anthropic'],
['claude-2.1', 'anthropic'],
];
const LIMIT = 1;

Expand Down
52 changes: 52 additions & 0 deletions src/cookbook/tracing_with_images_open_ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import * as dotenv from 'dotenv';
import { trace, traceInsert } from '../utils/trace_utils';
import OpenAI from 'openai';
import { patchOpenAI } from '../utils/wrap_openai';
import { Parea } from '../client';

dotenv.config();

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

// needed for tracing
new Parea(process.env.PAREA_API_KEY);

// Patch OpenAI to add trace logs
patchOpenAI(openai);

const imageMaker = trace('imageMaker', async (query: string): Promise<string | undefined> => {
const response = await openai.images.generate({ prompt: query, model: 'dall-e-3' });
const image_url = response.data[0].url;
const caption = { original_prompt: query, revised_prompt: response.data[0].revised_prompt };
traceInsert({ images: [{ url: image_url, caption: JSON.stringify(caption) }] });
return image_url;
});

const askVision = trace('askVision', async (image_url: string): Promise<string | null> => {
const response = await openai.chat.completions.create({
model: 'gpt-4-vision-preview',
messages: [
{
role: 'user',
content: [
{ type: 'text', text: 'What’s in this image?' },
{ type: 'image_url', image_url: { url: image_url } },
],
},
],
});
return response.choices[0].message.content;
});

const imageChain = trace('imageChain', async (query: string) => {
const image_url = await imageMaker(query);
return await askVision(image_url);
});

async function main(query: string) {
return await imageChain(query);
}

main('A dog sitting comfortably on a bed').then((result) => console.log(result));
36 changes: 36 additions & 0 deletions src/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { v4 as uuidv4 } from 'uuid';
import { Completion, Log, TraceLog, UpdateLog } from './types';

export function genTraceId(): string {
// Generate a unique trace id for each chain of requests
Expand Down Expand Up @@ -49,6 +50,41 @@ export async function* asyncPool<T, R>(
}
}

export type LogData = Completion & TraceLog & Log;

export function serializeMetadataValues(logData: LogData): LogData {
if (logData?.metadata) {
logData.metadata = serializeValues(logData?.metadata);
}

// Support openai vision content format
if (logData?.configuration) {
logData?.configuration?.messages?.forEach((message) => {
// noinspection SuspiciousTypeOfGuard
if (typeof message.content !== 'string') {
message.content = JSON.stringify(message.content);
}
});
}

return logData;
}

export function serializeValues(metadata: { [key: string]: any }): { [key: string]: string } {
const serialized: { [key: string]: string } = {};
for (const [key, value] of Object.entries(metadata)) {
serialized[key] = typeof value === 'string' ? value : JSON.stringify(value);
}
return serialized;
}

export function serializeMetadataValuesUpdate(logData: UpdateLog): UpdateLog {
if (logData?.field_name_to_value_map?.metadata) {
logData.field_name_to_value_map.metadata = serializeValues(logData.field_name_to_value_map.metadata);
}
return logData;
}

export const NOUNS: string[] = [
'abac',
'abbs',
Expand Down
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export {
TraceLogInputs,
Log,
TraceLog,
TraceLogImage,
TraceLogTreeSchema,
EvaluationResult,
TraceOptions,
Expand Down
10 changes: 8 additions & 2 deletions src/parea_logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { LangchainRunCreate, TraceIntegrations, TraceLog, UpdateLog } from './ty
import { AxiosResponse } from 'axios';
import { HTTPClient } from './api-client';
import { pareaProject } from './project';
import { serializeMetadataValues, serializeMetadataValuesUpdate } from './helpers';

const LOG_ENDPOINT = '/trace_log';
const VENDOR_LOG_ENDPOINT = '/trace_log/{vendor}';
Expand All @@ -18,10 +19,11 @@ export class PareaLogger {
}

public async recordLog(data: TraceLog): Promise<AxiosResponse<any>> {
const log = { ...data, project_uuid: await pareaProject.getProjectUUID() };
return await this.client.request({
method: 'POST',
endpoint: LOG_ENDPOINT,
data: { ...data, project_uuid: await pareaProject.getProjectUUID() },
data: serializeMetadataValues(log),
});
}

Expand All @@ -34,7 +36,11 @@ export class PareaLogger {
}

public async updateLog(data: UpdateLog): Promise<AxiosResponse<any>> {
return await this.client.request({ method: 'PUT', endpoint: LOG_ENDPOINT, data });
return await this.client.request({
method: 'PUT',
endpoint: LOG_ENDPOINT,
data: serializeMetadataValuesUpdate(data),
});
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ export type Completion = {
log_omit_outputs?: boolean;
log_omit?: boolean;
experiment_uuid?: string | null;
project_uuid?: string;
};

export type CompletionResponse = {
Expand Down Expand Up @@ -130,6 +131,11 @@ export type EvaluatedLog = Log & {
scores?: EvaluationResult[];
};

export type TraceLogImage = {
url: string;
caption?: string;
};

export type TraceLog = EvaluatedLog & {
trace_id: string;
parent_trace_id?: string;
Expand All @@ -151,6 +157,7 @@ export type TraceLog = EvaluatedLog & {
metadata?: { [key: string]: any };
tags?: string[];
experiment_uuid?: string | null;
images?: TraceLogImage[];
};

export type TraceLogTreeSchema = TraceLog & {
Expand Down
8 changes: 2 additions & 6 deletions src/utils/trace_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ const merge = (old: any, newValue: any) => {

/**
* Insert data into the trace log for the current or specified trace id. Data should be a dictionary with keys that correspond to the fields of the TraceLog model.
* If the field already has an existing value that is extensible (dict, set, list, etc.), the new value will be merged with the existing value.
*
* @param data = list of key-value pairs where keys represent input names.
* Each item in the list represent a test case row.
* Target and Tags are reserved keys. There can only be one target and tags key per dict item.
* If target is present it will represent the target/expected response for the inputs.
* If tags are present they must be a list of json_serializable values.
* @param data - Keys can be one of: trace_name, end_user_identifier, metadata, tags, deployment_id
* @param data - Keys can be one of: trace_name, end_user_identifier, metadata, tags, deployment_id, images
* @param traceId - The trace id to insert the data into. If not provided, the current trace id will be used.
* @returns void
*/
Expand Down

0 comments on commit b1b8ca6

Please sign in to comment.