Skip to content

Commit

Permalink
Merge pull request #19 from parea-ai/test_collection_in_experiment
Browse files Browse the repository at this point in the history
feat(experiment): use test collection in experiment
  • Loading branch information
jalexanderII authored Feb 7, 2024
2 parents 8efb9df + 242a21c commit 5eb5d3c
Show file tree
Hide file tree
Showing 8 changed files with 1,949 additions and 1,814 deletions.
15 changes: 14 additions & 1 deletion src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ExperimentSchema,
ExperimentStatsSchema,
FeedbackRequest,
TestCaseCollection,
UseDeployedPrompt,
UseDeployedPromptResponse,
} from './types';
Expand All @@ -23,6 +24,7 @@ const RECORD_FEEDBACK_ENDPOINT = '/feedback';
// Endpoint path templates. Segments in curly braces are placeholders that are
// substituted with concrete values before a request is sent.
const EXPERIMENT_ENDPOINT = '/experiment';
const EXPERIMENT_STATS_ENDPOINT = '/experiment/{experiment_uuid}/stats';
const EXPERIMENT_FINISHED_ENDPOINT = '/experiment/{experiment_uuid}/finished';
// `{test_collection_name}` is filled in with the name of the test collection to fetch.
const GET_COLLECTION_ENDPOINT = '/collection/{test_collection_name}';

export class Parea {
private apiKey: string;
Expand Down Expand Up @@ -112,11 +114,22 @@ export class Parea {
return response.data;
}

/**
 * Fetches a saved test collection by name.
 *
 * @param testCollectionName - Name of the test collection to retrieve.
 * @returns The `data` payload of the response to the GET request.
 */
public async getCollection(testCollectionName: string): Promise<TestCaseCollection> {
  // Percent-encode the name before placing it in the URL path: collection names
  // may contain spaces, '/', '?' or '#', which would otherwise change the path.
  // Encoding also removes '$', so the name can never be misread as a special
  // replacement pattern (e.g. "$&") by String.prototype.replace.
  const encodedName = encodeURIComponent(testCollectionName);
  const response = await this.client.request({
    method: 'GET',
    endpoint: GET_COLLECTION_ENDPOINT.replace('{test_collection_name}', encodedName),
  });
  return response.data;
}

/**
 * Creates an Experiment that evaluates `func` against `data`.
 *
 * @param data - Either the name of a saved test collection (handed to the
 *   Experiment unchanged) or an iterable of data items. Each data item is
 *   converted to the array of its property values via `Object.values`.
 * @param func - Async function under test; it accepts the values of one data
 *   item as arguments.
 * @param name - Experiment name; defaults to the empty string.
 * @returns A new Experiment bound to this client.
 */
public experiment(
  data: string | Iterable<DataItem>,
  func: (...dataItem: any[]) => Promise<any>,
  name: string = '',
): Experiment {
  // A string is a test collection name, not row data: pass it through as-is.
  if (typeof data === 'string') {
    return new Experiment(data, func, name, this);
  }
  // Array.from's mapping overload converts each item in a single pass.
  const convertedData: Iterable<any[]> = Array.from(data, (item) => Object.values(item));
  return new Experiment(convertedData, func, name, this);
}
Expand Down
49 changes: 49 additions & 0 deletions src/cookbook/experiment_using_saved_test_collection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { Parea } from '../client';
import { Completion, CompletionResponse, Log, Message } from '../types';
import { trace } from '../utils/trace_utils';
import * as dotenv from 'dotenv';

// Load variables from a local .env file into process.env (dotenv) before the API key is read.
dotenv.config();

// Shared Parea client for this example, constructed with the DEV_API_KEY environment variable.
const p = new Parea(process.env.DEV_API_KEY);

/**
 * Example evaluation function for a trace log.
 *
 * @param log - Trace log whose `inputs['x']` value is inspected.
 * @returns 1.0 when the `x` input equals "python" (case-insensitive);
 *   otherwise a random score in [0, 1).
 */
function evalFunc(log: Log): number {
  const language = log.inputs?.['x']?.toLowerCase();
  return language === 'python' ? 1.0 : Math.random();
}

/**
 * Requests a completion from Parea for the given chat messages.
 *
 * @param data - Messages to send as the conversation.
 * @returns The completion response returned by the Parea client.
 */
const callLLM = async (data: Message[]): Promise<CompletionResponse> => {
  // Fixed model settings for this example: gpt-4 with `temp` set to 1.0.
  const llmConfiguration: Completion['llm_configuration'] = {
    model: 'gpt-4',
    model_params: { temp: 1.0 },
    messages: data,
  };
  const request: Completion = { llm_configuration: llmConfiguration };
  return await p.completion(request);
};

/**
 * Traced function under test: asks the LLM for a hello world program written
 * in language `x` using `y`, and returns the content of the response.
 * `evalFunc` is registered as the evaluation function for the trace.
 */
const helloWorld = trace(
  'helloWorld',
  async (x: string, y: string): Promise<string> => {
    const prompt: Message = { role: 'user', content: `Write a hello world program in ${x} using ${y}` };
    const { content } = await callLLM([prompt]);
    return content;
  },
  { evalFuncs: [evalFunc] },
);

/**
 * Builds an experiment that runs `helloWorld` against a saved test collection
 * and executes it.
 */
export async function main() {
  // This is the name of my Test Collection in Parea (TestHub page).
  const testCollectionName = 'Hello World Example';
  const helloWorldExperiment = p.experiment(testCollectionName, helloWorld);
  return await helloWorldExperiment.run();
}

// Run the example. Failures are reported explicitly and reflected in the exit
// code instead of being left as an unhandled promise rejection.
main()
  .then(() => {
    console.log('Experiment complete!');
  })
  .catch((error: unknown) => {
    console.error('Experiment failed:', error);
    process.exitCode = 1;
  });
32 changes: 24 additions & 8 deletions src/experiment/experiment.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ExperimentStatsSchema, TraceStatsSchema } from '../types';
import { ExperimentStatsSchema, TestCaseCollection, TraceStatsSchema } from '../types';
import { Parea } from '../client';
import { asyncPool } from '../helpers';
import { genRandomName } from './genRandomName';
import { genRandomName } from './utils';

function calculateAvgAsString(values: number[] | undefined): string {
if (!values || values.length === 0) {
Expand Down Expand Up @@ -47,11 +47,26 @@ function calculateAvgStdForExperiment(experimentStats: ExperimentStatsSchema): {

async function experiment(
name: string,
data: Iterable<any[]>,
data: string | Iterable<any[]>,
func: (...dataItem: any[]) => Promise<any>,
p: Parea,
maxParallelCalls: number = 10,
): Promise<ExperimentStatsSchema> {
if (typeof data === 'string') {
console.log(`Fetching test collection: ${data}`);
const response = await p.getCollection(data);
const testCollection = new TestCaseCollection(
response.id,
response.name,
response.created_at,
response.last_updated_at,
response.column_names,
response.test_cases,
);
console.log(`Fetched ${testCollection.numTestCases()} test cases from collection: ${data} \n`);
data = testCollection.getAllTestCaseInputs();
}

const experimentSchema = await p.createExperiment({ name });
const experimentUUID = experimentSchema.uuid;
process.env.PAREA_OS_ENV_EXPERIMENT_UUID = experimentUUID;
Expand All @@ -67,26 +82,27 @@ async function experiment(

const experimentStats: ExperimentStatsSchema = await p.finishExperiment(experimentUUID);
const statNameToAvgStd = calculateAvgStdForExperiment(experimentStats);
console.log(`Experiment stats:\n${JSON.stringify(statNameToAvgStd, null, 2)}\n\n`);
console.log(`Experiment ${name} stats:\n${JSON.stringify(statNameToAvgStd, null, 2)}\n\n`);
console.log(`View experiment & its traces at: https://app.parea.ai/experiments/${experimentUUID}\n`);
return experimentStats;
}

export class Experiment {
name: string;
data: Iterable<any[]>;
data: string | Iterable<any[]>;
func: (...dataItem: any[]) => Promise<any>;
p: Parea;
experimentStats?: ExperimentStatsSchema;

constructor(data: Iterable<any[]>, func: (...dataItem: any[]) => Promise<any>, name: string, p: Parea) {
this.name = name || genRandomName();
constructor(data: string | Iterable<any[]>, func: (...dataItem: any[]) => Promise<any>, name: string, p: Parea) {
this.name = name;
this.data = data;
this.func = func;
this.p = p;
}

async run(): Promise<void> {
async run(name: string | undefined = undefined): Promise<void> {
this.name = name || genRandomName();
this.experimentStats = new ExperimentStatsSchema(
(await experiment(this.name, this.data, this.func, this.p)).parent_trace_stats,
);
Expand Down
Loading

0 comments on commit 5eb5d3c

Please sign in to comment.