Skip to content

Commit

Permalink
Merge pull request #23 from parea-ai/PAI-737-enable-dataset-level-eva…
Browse files Browse the repository at this point in the history
…luation-metrics

Pai 737 enable dataset level evaluation metrics
  • Loading branch information
joschkabraun authored Feb 27, 2024
2 parents 796bbf1 + 830076a commit 074f4c6
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 24 deletions.
7 changes: 6 additions & 1 deletion src/api-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ interface RequestConfig {

export class HTTPClient {
private static instance: HTTPClient;
private baseURL: string = 'https://parea-ai-backend-us-9ac16cdbc7a7b006.onporter.run/api/parea/v1';
private baseURL: string;
private apiKey: string | null = null;
private client: AxiosInstance;

Expand Down Expand Up @@ -39,6 +39,11 @@ export class HTTPClient {
this.apiKey = apiKey;
}

/**
 * Changes the API root used by this client.
 *
 * The value is written to two places: the default `baseURL` of the wrapped
 * axios instance and this object's own `baseURL` field.
 *
 * @param baseURL - Root URL to store.
 */
public setBaseURL(baseURL: string): void {
  this.client.defaults.baseURL = baseURL;
  this.baseURL = baseURL;
}

public async request(config: RequestConfig): Promise<AxiosResponse<any>> {
const headers = { 'x-api-key': this.apiKey || config.apiKey || '' };
try {
Expand Down
17 changes: 13 additions & 4 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ import {
CompletionResponse,
CreateExperimentRequest,
DataItem,
ExperimentOptions,
ExperimentSchema,
ExperimentStatsSchema,
FeedbackRequest,
FinishExperimentRequestSchema,
TestCaseCollection,
UseDeployedPrompt,
UseDeployedPromptResponse,
Expand Down Expand Up @@ -34,6 +36,9 @@ export class Parea {
this.apiKey = apiKey;
this.client = HTTPClient.getInstance();
this.client.setApiKey(this.apiKey);
this.client.setBaseURL(
process.env.PAREA_BASE_URL || 'https://parea-ai-backend-us-9ac16cdbc7a7b006.onporter.run/api/parea/v1',
);
pareaLogger.setClient(this.client);
pareaProject.setProjectName(projectName);
pareaProject.setClient(this.client);
Expand Down Expand Up @@ -106,10 +111,14 @@ export class Parea {
return response.data;
}

public async finishExperiment(experimentUUID: string): Promise<ExperimentStatsSchema> {
/**
 * Tells the API that an experiment is finished and returns the response payload.
 *
 * @param experimentUUID - UUID substituted for the `{experiment_uuid}` placeholder
 *   of the finish-experiment endpoint.
 * @param fin_req - Body of the POST request. Defaults to an empty object so that
 *   callers written against the earlier one-argument signature keep working.
 * @returns The `data` field of the HTTP response.
 */
public async finishExperiment(
  experimentUUID: string,
  fin_req: FinishExperimentRequestSchema = {},
): Promise<ExperimentStatsSchema> {
  const response = await this.client.request({
    method: 'POST',
    endpoint: EXPERIMENT_FINISHED_ENDPOINT.replace('{experiment_uuid}', experimentUUID),
    data: fin_req,
  });
  return response.data;
}
Expand All @@ -125,11 +134,11 @@ export class Parea {
public experiment(
data: string | Iterable<DataItem>,
func: (...dataItem: any[]) => Promise<any>,
metadata?: { [key: string]: string },
options?: ExperimentOptions,
): Experiment {
if (typeof data === 'string') {
return new Experiment(data, func, '', this, metadata);
return new Experiment(data, func, '', this, options?.metadata, options?.datasetLevelEvalFuncs);
}
return new Experiment(data, func, '', this, metadata);
return new Experiment(data, func, '', this, options?.metadata, options?.datasetLevelEvalFuncs);
}
}
69 changes: 69 additions & 0 deletions src/cookbook/experiment_balanced_acc.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { Parea } from '../client';
import { trace } from '../utils/trace_utils';
import * as dotenv from 'dotenv';
import { EvaluatedLog, Log } from '../types';

// Load variables from a local .env file into process.env before they are read below.
dotenv.config();

// Client used to create the experiment in main(); the API key comes from the PAREA_API_KEY environment variable.
const p = new Parea(process.env.PAREA_API_KEY);

/**
 * Per-log evaluation: compares a log's output with its target.
 *
 * @param log - Log whose `output` and `target` fields are compared with strict equality.
 * @returns 1 when `output` and `target` are strictly equal, otherwise 0.
 */
function isCorrect(log: Log): number {
  if (log?.output === log.target) {
    return 1;
  }
  return 0;
}

/**
 * Toy classifier wrapped with `trace` under the name 'startsWithF', with
 * `isCorrect` passed as its `evalFuncs`.
 *
 * Returns '1' only when `name` is exactly 'Foo' and '0' for every other input.
 * NOTE(review): despite the name this is an exact match, not a prefix check;
 * presumably intentional so that the sample data contains a misclassification.
 */
const startsWithF = trace(
  'startsWithF',
  (name: string): string => (name === 'Foo' ? '1' : '0'),
  { evalFuncs: [isCorrect] },
);

/**
 * Dataset-level metric: balanced accuracy of the `isCorrect` score.
 *
 * Logs are grouped by their `target` (a missing/empty target is grouped under
 * ''). For each group the recall is the fraction of logs whose `isCorrect`
 * score is truthy; the result is the unweighted mean of those recalls.
 * Logs that carry no score named after `isCorrect` are ignored.
 *
 * @param logs - Evaluated logs of the experiment.
 * @returns Mean per-target recall, or 0 when no log has an `isCorrect` score.
 */
function balancedAccIsCorrect(logs: EvaluatedLog[]): number {
  const scoreName = isCorrect.name;

  // A Map (not a plain object) so that targets such as 'constructor' or
  // 'toString' cannot collide with properties inherited from Object.prototype.
  const tallies = new Map<string, { correct: number; total: number }>();

  for (const log of logs) {
    const evalResult = log?.scores?.find((score) => score.name === scoreName);
    if (evalResult === undefined) {
      continue;
    }
    const target = log.target || '';
    const tally = tallies.get(target) ?? { correct: 0, total: 0 };
    tally.correct += evalResult.score ? 1 : 0;
    tally.total += 1;
    tallies.set(target, tally);
  }

  if (tallies.size === 0) {
    return 0;
  }

  let recallSum = 0;
  for (const { correct, total } of tallies.values()) {
    recallSum += correct / total;
  }
  return recallSum / tallies.size;
}

/**
 * Builds an experiment over a small in-memory dataset, with `startsWithF` as
 * the function under test and `balancedAccIsCorrect` as the dataset-level
 * evaluation function, then runs it.
 *
 * @returns Whatever `Experiment.run()` resolves to.
 */
export async function main() {
  // Data to run the experiment on: each item supplies a `name` and a `target`.
  const dataset = [
    { name: 'Foo', target: '1' },
    { name: 'Bar', target: '0' },
    { name: 'Far', target: '1' },
  ];
  const experimentRun = p.experiment(dataset, startsWithF, {
    datasetLevelEvalFuncs: [balancedAccIsCorrect],
  });
  return await experimentRun.run();
}

// Run the example. Failures are reported and reflected in the exit code
// instead of being left as an unhandled promise rejection.
main()
  .then(() => {
    console.log('Experiment complete!');
  })
  .catch((error: unknown) => {
    console.error('Experiment failed:', error);
    process.exitCode = 1;
  });
2 changes: 1 addition & 1 deletion src/cookbook/experiment_random_numbers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import * as dotenv from 'dotenv';

dotenv.config();

const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

function isBetween1AndN(log: Log): number {
// Evaluates if the number is between 1 and n
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/experiment_using_saved_test_collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import * as dotenv from 'dotenv';

dotenv.config();

const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

function evalFunc(log: Log): number {
if (log.inputs?.['x']?.toLowerCase() === 'python') {
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/langchainExamples/tracing_langchain_simple.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { LLMChain } from 'langchain/chains';

dotenv.config();

new Parea(process.env.DEV_API_KEY);
new Parea(process.env.PAREA_API_KEY);
const handler = new PareaAILangchainTracer();
export const run = async () => {
const llm = new OpenAI({ temperature: 0 });
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/tracing_with_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Parea } from '../client';

dotenv.config();

const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

const LLM_OPTIONS = [
['gpt-3.5-turbo', 'openai'],
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/tracing_with_deployed_prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { trace } from '../utils/trace_utils';

dotenv.config();

const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

const deployedArgumentGenerator = async (query: string, additionalDescription: string = ''): Promise<string> => {
const completion: Completion = {
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/tracing_with_openai_endpoint_directly.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const openai = new OpenAI({
});

// needed for tracing
const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

// Patch OpenAI to add trace logs
patchOpenAI(openai);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const openai = new OpenAI({
});

// needed for tracing
new Parea(process.env.DEV_API_KEY);
new Parea(process.env.PAREA_API_KEY);

// Patch OpenAI to add trace logs
patchOpenAI(openai);
Expand Down
2 changes: 1 addition & 1 deletion src/cookbook/tracing_without_deployed_prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { getCurrentTraceId, trace } from '../utils/trace_utils';

dotenv.config();

const p = new Parea(process.env.DEV_API_KEY);
const p = new Parea(process.env.PAREA_API_KEY);

// If you want to log the inputs to the LLM call, you can optionally add a trace wrapper
const callLLM = trace(
Expand Down
41 changes: 38 additions & 3 deletions src/experiment/experiment.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import { DataItem, ExperimentStatsSchema, TestCaseCollection, TraceStatsSchema } from '../types';
import {
DataItem,
EvaluatedLog,
ExperimentStatsSchema,
EvaluationResult,
TestCaseCollection,
TraceStatsSchema,
} from '../types';
import { Parea } from '../client';
import { asyncPool } from '../helpers';
import { genRandomName } from './utils';
import { rootTraces } from '../utils/trace_utils';

function calculateAvgAsString(values: number[] | undefined): string {
if (!values || values.length === 0) {
Expand Down Expand Up @@ -51,6 +59,7 @@ async function experiment(
func: (...dataItem: any[]) => Promise<any>,
p: Parea,
metadata?: { [key: string]: string } | undefined,
datasetLevelEvalFuncs?: ((logs: EvaluatedLog[]) => Promise<number | null | undefined>)[],
maxParallelCalls: number = 10,
): Promise<ExperimentStatsSchema> {
if (typeof data === 'string') {
Expand Down Expand Up @@ -83,8 +92,29 @@ async function experiment(
void _;
}

const experimentStats: ExperimentStatsSchema = await p.finishExperiment(experimentUUID);
const datasetLevelEvalPromises: Promise<EvaluationResult | null>[] =
datasetLevelEvalFuncs?.map(async (func): Promise<EvaluationResult | null> => {
try {
const score = await func(Array.from(rootTraces.values()));
if (score !== undefined && score !== null) {
return { name: func.name, score };
}
} catch (e) {
console.error(`Error occurred calling '${func.name}', ${e}`, e);
}
return null;
}) || [];
const datasetLevelEvaluationResults = (await Promise.all(datasetLevelEvalPromises)).filter(
(x): x is EvaluationResult => x !== null,
);

const experimentStats: ExperimentStatsSchema = await p.finishExperiment(experimentUUID, {
dataset_level_stats: datasetLevelEvaluationResults,
});
const statNameToAvgStd = calculateAvgStdForExperiment(experimentStats);
datasetLevelEvaluationResults.forEach((result) => {
statNameToAvgStd[result.name] = result.score.toFixed(2);
});
console.log(`Experiment ${name} stats:\n${JSON.stringify(statNameToAvgStd, null, 2)}\n\n`);
console.log(`View experiment & its traces at: https://app.parea.ai/experiments/${experimentUUID}\n`);
return experimentStats;
Expand All @@ -97,19 +127,22 @@ export class Experiment {
p: Parea;
experimentStats?: ExperimentStatsSchema;
metadata?: { [key: string]: string };
datasetLevelEvalFuncs?: ((logs: EvaluatedLog[]) => Promise<number | null | undefined>)[];

constructor(
data: string | Iterable<DataItem>,
func: (...dataItem: any[]) => Promise<any>,
name: string,
p: Parea,
metadata?: { [key: string]: string },
datasetLevelEvalFuncs?: ((logs: EvaluatedLog[]) => Promise<number | null | undefined>)[],
) {
this.name = name;
this.data = data;
this.func = func;
this.p = p;
this.metadata = metadata;
this.datasetLevelEvalFuncs = datasetLevelEvalFuncs;
if (typeof data === 'string') {
if (!this.metadata) {
this.metadata = {};
Expand All @@ -125,7 +158,9 @@ export class Experiment {
async run(name: string | undefined = undefined): Promise<void> {
this.name = name || genRandomName();
this.experimentStats = new ExperimentStatsSchema(
(await experiment(this.name, this.data, this.func, this.p, this.metadata)).parent_trace_stats,
(
await experiment(this.name, this.data, this.func, this.p, this.metadata, this.datasetLevelEvalFuncs)
).parent_trace_stats,
);
}
}
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export {
Log,
TraceLog,
TraceLogTreeSchema,
NamedEvaluationScore,
EvaluationResult,
TraceOptions,
UpdateLog,
CreateExperimentRequest,
Expand Down
20 changes: 16 additions & 4 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export type TraceLogInputs = {
[key: string]: string;
};

export type NamedEvaluationScore = {
/** A named numeric score, such as the result of one evaluation function. */
export type EvaluationResult = {
name: string;
score: number;
};
Expand All @@ -126,7 +126,11 @@ export type Log = {
cost?: number;
};

export type TraceLog = Log & {
/** A `Log` extended with the evaluation scores attached to it, if any. */
export type EvaluatedLog = Log & {
scores?: EvaluationResult[];
};

export type TraceLog = EvaluatedLog & {
trace_id: string;
parent_trace_id?: string;
root_trace_id?: string;
Expand All @@ -139,7 +143,6 @@ export type TraceLog = Log & {
apply_eval_frac?: number;
cache_hit?: boolean;
evaluation_metric_names?: string[];
scores?: NamedEvaluationScore[];
feedback_score?: number;
trace_name?: string;
children: string[];
Expand Down Expand Up @@ -179,7 +182,7 @@ export type ExperimentSchema = CreateExperimentRequest & {
created_at: string;
};

export type EvaluationScoreSchema = NamedEvaluationScore & {
export type EvaluationScoreSchema = EvaluationResult & {
id?: number;
};

Expand Down Expand Up @@ -345,3 +348,12 @@ export class TestCaseCollection {
}));
}
}

/** Request body sent to the finish-experiment endpoint. */
export type FinishExperimentRequestSchema = {
// Scores produced by the dataset-level evaluation functions (name + score per function).
dataset_level_stats?: EvaluationResult[];
};

/** Optional settings accepted as the third argument of `Parea.experiment`. */
export type ExperimentOptions = {
// Forwarded to the `Experiment` constructor as its metadata argument.
metadata?: { [key: string]: string };
// Forwarded to the `Experiment` constructor; each function is later called with the evaluated logs.
// NOTE(review): typed as any[], so callers get no checking of the function signature —
// consider a (logs: EvaluatedLog[]) => number | null | undefined (sync or async) type once
// the Experiment constructor accepts the same union.
datasetLevelEvalFuncs?: any[];
};
Loading

0 comments on commit 074f4c6

Please sign in to comment.