-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from parea-ai/PAI-646-dataset-support-in-ts
feat(datasets): add dataset endpoints
- Loading branch information
Showing
5 changed files
with
161 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import { Parea } from '../client'; | ||
import * as dotenv from 'dotenv'; | ||
|
||
// Load variables from a local .env file into process.env before reading the API key.
dotenv.config();

// Shared Parea client used by the example below; the key is read from PAREA_API_KEY.
const apiKey = process.env.PAREA_API_KEY;
const p = new Parea(apiKey);
|
||
/**
 * Example: create a dataset from in-memory rows, then append one more test case to it.
 *
 * Each client call is awaited so the steps run in order and the returned promise
 * only resolves once all of them have finished; a failing call rejects that promise.
 * (Assumes the client methods return promises; awaiting a plain value is harmless.)
 */
export async function main() {
  // Use one name for both the create and the add step so they target the same dataset.
  const datasetName = 'Math problems';

  const data = [
    {
      problem: '1+2',
      target: 3,
      tags: ['easy'],
    },
    { problem: 'Solve the differential equation dy/dx = 3y.', target: 'y = c * e^(3x)', tags: ['hard'] },
  ];

  // This will create a new dataset on Parea named "Math problems".
  // The dataset will have one column named "problem", and two columns using the reserved names "target" and "tags".
  // When using this dataset, the expected prompt template should have a placeholder for the variable "problem".
  await p.createTestCollection(data, datasetName);

  const new_data = [{ problem: 'Evaluate the integral ∫x^2 dx from 0 to 3.', target: 9, tags: ['hard'] }];
  // This will add the new test cases to the existing "Math problems" dataset.
  // New test cases must have the same columns as the existing dataset.
  await p.addTestCases(new_data, datasetName);
  // Alternatively, address the dataset by its ID instead of its name.
  // NOTE: 121 is an example ID; replace it with the ID of your own dataset.
  await p.addTestCases(new_data, undefined, 121);
}
|
||
// Run the example. Failures are logged and reflected in the exit code instead of
// being left as an unhandled promise rejection.
main()
  .then(() => {
    console.log('Done!');
  })
  .catch((error: unknown) => {
    console.error(error);
    process.exitCode = 1;
  });
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import { genRandomName } from './utils'; | ||
import { CreateTestCase, CreateTestCaseCollection } from '../types'; | ||
|
||
/**
 * Build a test case collection payload from a list of plain row objects.
 *
 * @param data - List of key-value rows; each row is one test case and each key is an input name.
 *   `target` and `tags` are reserved keys and are not reported as columns.
 *   When present, `target` is the expected response for the row's inputs and
 *   `tags` must be a list of JSON-serializable values.
 * @param name - Name for the collection. When omitted (or empty), a random name is generated.
 * @returns The collection name, the non-reserved column names in first-seen order,
 *   and the converted test cases.
 */
export async function createTestCollection(
  data: Record<string, any>[],
  name?: string,
): Promise<CreateTestCaseCollection> {
  const collectionName = name ? name : genRandomName();

  // Collect every non-reserved key across all rows; a Set keeps first-seen order and drops repeats.
  const seenColumns = new Set<string>();
  for (const row of data) {
    for (const key of Object.keys(row)) {
      if (key === 'target' || key === 'tags') {
        continue;
      }
      seenColumns.add(key);
    }
  }

  const convertedCases = await createTestCases(data);

  return {
    name: collectionName,
    column_names: Array.from(seenColumns),
    test_cases: convertedCases,
  };
}
|
||
/**
 * Convert a value to the string form stored in a test case:
 * strings are kept as-is, anything else is pretty-printed JSON.
 */
function toTestCaseString(value: unknown): string {
  return typeof value === 'string' ? value : JSON.stringify(value, null, 2);
}

/**
 * Create a list of test cases from a list of plain row objects.
 *
 * @param data - List of key-value rows; each row is one test case and each key is an input name.
 *   `target` and `tags` are reserved keys:
 *   - `target`, when present, is always JSON-encoded (so a string target keeps its surrounding quotes).
 *   - `tags`, when present, must be an array; string tags are kept as-is, others are JSON-encoded.
 *   Every other key becomes an input; string values are kept as-is, others are JSON-encoded.
 * @returns One test case per row, in the same order as `data`. Rows without `target`
 *   get `target: undefined`; rows without `tags` get an empty tag list.
 * @throws Error if a row's `tags` value is not an array.
 */
export async function createTestCases(data: Record<string, any>[]): Promise<CreateTestCase[]> {
  const testCases: CreateTestCase[] = [];

  for (const row of data) {
    const inputs: Record<string, string> = {};
    let target: string | undefined;
    let tags: string[] = [];

    // An object yields each own key at most once, so `target` and `tags` can each
    // appear only once per row; no duplicate-key handling is needed.
    for (const [key, value] of Object.entries(row)) {
      if (key === 'target') {
        target = JSON.stringify(value, null, 2);
      } else if (key === 'tags') {
        if (!Array.isArray(value)) {
          throw new Error('Tags must be a list of json serializable values.');
        }
        tags = value.map((tag) => toTestCaseString(tag));
      } else {
        inputs[key] = toTestCaseString(value);
      }
    }

    testCases.push({ inputs, target, tags });
  }

  return testCases;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters