Skip to content

Commit

Permalink
refactor: almost done old API
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed Jun 6, 2024
1 parent c163bc5 commit 7c6a7ad
Show file tree
Hide file tree
Showing 23 changed files with 1,766 additions and 259 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
target
*.node
*.so
*.dll
*.dylib
**/node_modules
**/.DS_Store
npm-debug.log
cargo.log
cross.log
lib/*.js
lib/*.js
bin/**/*
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ crate-type = ["cdylib"]
neon = { version = "1", default-features = false, features = ["napi-6"] }
cpp = "0.5"
cpp_macros = "0.5"
libc = "0.2"

[build-dependencies]
cpp_build = "0.5"
Expand Down
10 changes: 8 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@ fn link_lib(lib_path: &Path, lib: &str, whole_link: bool) -> Result<(), ()> {
}
}
return Ok(());
} else {
eprintln!("{} not found", filename);
}
Err(())
}

fn main() {
println!("cargo:rerun-if-changed=src/sampler.rs");
println!("cargo:rerun-if-changed=src/tensor.rs");
println!("cargo:rerun-if-changed=src/tensor.hpp");
println!("cargo:rerun-if-changed=src/module.rs");
println!("cargo:rerun-if-changed=src/module.hpp");
println!("cargo:rerun-if-changed=src/method_meta.rs");
println!("cargo:rerun-if-changed=src/evalue.rs");
println!("cargo:rerun-if-changed=src/evalue.hpp");
println!("cargo:rerun-if-changed=src/eterror.rs");
println!("cargo:rerun-if-changed=src/lib.rs");

let install_prefix = std::env::var("EXECUTORCH_INSTALL_PREFIX").unwrap_or_else(|_| "executorch/cmake-out".to_string());
let lib_path = Path::new(&install_prefix).join("lib");
Expand Down
67 changes: 67 additions & 0 deletions lib/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP

exports[`Module 1`] = `
Float32Array [
0,
0,
0,
0,
0,
0,
]
`;

exports[`Module 2`] = `
Float32Array [
1,
4,
9,
16,
25,
36,
]
`;

exports[`Tensor 1`] = `
Float32Array [
2,
4,
6,
]
`;

exports[`Tensor 2`] = `
Float32Array [
0,
4,
6,
]
`;

exports[`Tensor 3`] = `
Float32Array [
1,
2,
1,
2,
3,
4,
3,
4,
5,
6,
5,
6,
]
`;

exports[`Tensor 4`] = `
Float32Array [
1,
2,
3,
4,
5,
6,
]
`;
26 changes: 10 additions & 16 deletions lib/binding.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { ExternalObject, InternalEValue, MethodMeta, TensorData, Optional, TensorPtrInfo } from "./types";
import type { ExternalObject, InternalEValue, MethodMeta, Optional, TensorPtrInfo, DType } from "./types";

interface Binding {
// module methods
Expand All @@ -7,25 +7,19 @@ interface Binding {
moduleExecute(ptr: ExternalObject, method_name: string, inputs: InternalEValue[]): Promise<InternalEValue[]>;
moduleGetMethodMeta(ptr: ExternalObject, method_name: string): MethodMeta;
moduleMethodNames(ptr: ExternalObject): string[];
moduleDispose(ptr: ExternalObject): void;
// sampler methods
createSampler(vocab_size: number, temperature: number, top_p: number, seed: number): ExternalObject;
samplerSample(ptr: ExternalObject, vector: Float32Array): number;
samplerSample(ptr: ExternalObject, tensor: ExternalObject): number;
// tensor methods
createU8Tensor(data: Uint8Array, shape: number[]): ExternalObject;
createI8Tensor(data: Int8Array, shape: number[]): ExternalObject;
createI16Tensor(data: Int16Array, shape: number[]): ExternalObject;
createI32Tensor(data: Int32Array, shape: number[]): ExternalObject;
createI64Tensor(data: BigInt64Array, shape: number[]): ExternalObject;
createF32Tensor(data: Float32Array, shape: number[]): ExternalObject;
createF64Tensor(data: Float64Array, shape: number[]): ExternalObject;
tensorGetData(ptr: ExternalObject): TensorData;
tensorSetData(ptr: ExternalObject, data: TensorData): void;
tensorConcat(ptrs: ExternalObject[], axis: number): TensorPtrInfo;
tensorSlice(ptr: ExternalObject, ...slice_position: Array<Optional<Array<Optional<number>>|number>>): TensorPtrInfo;
tensorReshape(ptr: ExternalObject, shape: number[]): void;
createTensor(dtype: DType, shape: number[], data: ArrayBuffer): ExternalObject;
tensorGetDtype(ptr: ExternalObject): DType;
tensorGetShape(ptr: ExternalObject): number[];
tensorGetData(ptr: ExternalObject): ArrayBuffer;
tensorSetData(ptr: ExternalObject, data: ArrayBuffer): void;
tensorSetValue(ptr: ExternalObject, position: Array<number>, data: number | boolean): void;
tensorDispose(ptr: ExternalObject): void;
tensorConcat(ptrs: ExternalObject[], axis: number): TensorPtrInfo;
tensorSlice(ptr: ExternalObject, slice_position: Array<Optional<Array<Optional<number>>|number>>): TensorPtrInfo;
tensorReshape(ptr: ExternalObject, shape: number[]): TensorPtrInfo;
}

const moduleBasePath = `../bin/${process.platform}/${process.arch}`;
Expand Down
110 changes: 55 additions & 55 deletions lib/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,72 +1,72 @@
import path from "path";
import { Module, Tensor, Sampler } from "./index";
import { Module, Tensor, Sampler, EValueTag, DType } from "./index";

const model = path.resolve(__dirname, "__fixtures__/mul.pte");

// it("Module", async () => {
// const mod = await Module.load(model);
// expect(mod.method_names).toEqual(["forward"]);
// expect(mod.getMethodMeta("forward")).toEqual({
// name: "forward",
// inputs: [
// { tag: "tensor", tensor_info: { dtype: "float32", shape: [3, 2] } },
// { tag: "tensor", tensor_info: { dtype: "float32", shape: [3, 2] } },
// ],
// outputs: [{ tag: "tensor", tensor_info: { dtype: "float32", shape: [3, 2] } }],
// });
// { // execute without inputs
// const outputs = await mod.execute("forward");
// expect(outputs[0]).toBeInstanceOf(Tensor);
// if (outputs[0] instanceof Tensor) {
// expect(outputs[0].dtype).toBe("float32");
// expect(outputs[0].shape).toEqual([3, 2]);
// expect(outputs[0].data).toMatchSnapshot();
// }
// }
// { // forward
// const input = new Tensor("float32", [3, 2], new Float32Array([1, 2, 3, 4, 5, 6]));
// const outputs = await mod.forward([input, input]);
// expect(outputs[0]).toBeInstanceOf(Tensor);
// if (outputs[0] instanceof Tensor) {
// expect(outputs[0].dtype).toBe("float32");
// expect(outputs[0].shape).toEqual([3, 2]);
// expect(outputs[0].data).toMatchSnapshot();
// }
// }
// });
it("Module", async () => {
const mod = await Module.load(model);
expect(mod.method_names).toEqual(["forward"]);
expect(mod.getMethodMeta("forward")).toEqual({
name: "forward",
inputs: [
{ tag: EValueTag.Tensor, tensor_info: { dtype: DType.float32, shape: [3, 2] } },
{ tag: EValueTag.Tensor, tensor_info: { dtype: DType.float32, shape: [3, 2] } },
],
outputs: [{ tag: EValueTag.Tensor, tensor_info: { dtype: DType.float32, shape: [3, 2] } }],
});
{ // execute without inputs
const outputs = await mod.execute("forward");
expect(outputs[0]).toBeInstanceOf(Tensor);
if (outputs[0] instanceof Tensor) {
expect(outputs[0].dtype).toBe("float32");
expect(outputs[0].shape).toEqual([3, 2]);
expect(outputs[0].data).toMatchSnapshot();
}
}
{ // forward
const input = new Tensor("float32", [3, 2], new Float32Array([1, 2, 3, 4, 5, 6]));
const outputs = await mod.forward([input, input]);
expect(outputs[0]).toBeInstanceOf(Tensor);
if (outputs[0] instanceof Tensor) {
expect(outputs[0].dtype).toBe("float32");
expect(outputs[0].shape).toEqual([3, 2]);
expect(outputs[0].data).toMatchSnapshot();
}
}
});

// it("Tensor", async () => {
// const input = new Tensor("float32", [3, 2], new Float32Array([1, 2, 3, 4, 5, 6]));
it("Tensor", async () => {
const input = new Tensor("float32", [3, 2], new Float32Array([1, 2, 3, 4, 5, 6]));

// // slice
// const slice = input.slice(null, [1, null]);
// expect(slice.dtype).toBe("float32");
// expect(slice.shape).toEqual([3, 1]);
// expect(slice.data).toMatchSnapshot();
// slice
const slice = input.slice(null, [1, null]);
expect(slice.dtype).toBe("float32");
expect(slice.shape).toEqual([3, 1]);
expect(slice.data).toMatchSnapshot();

// // setValue
// slice.setValue([0, 0], 0);
// expect(slice.data).toMatchSnapshot();
// setValue
slice.setValue([0, 0], 0);
expect(slice.data).toMatchSnapshot();

// // concat
// const concat = Tensor.concat([input, input], 1);
// expect(concat.dtype).toBe("float32");
// expect(concat.shape).toEqual([3, 4]);
// expect(concat.data).toMatchSnapshot();
// concat
const concat = Tensor.concat([input, input], 1);
expect(concat.dtype).toBe("float32");
expect(concat.shape).toEqual([3, 4]);
expect(concat.data).toMatchSnapshot();

// // reshape
// input.reshape([2, 3]);
// expect(input.dtype).toBe("float32");
// expect(input.shape).toEqual([2, 3]);
// expect(input.data).toMatchSnapshot();
// });
// reshape
const reshaped = input.reshape([2, 3]);
expect(reshaped.dtype).toBe("float32");
expect(reshaped.shape).toEqual([2, 3]);
expect(reshaped.data).toMatchSnapshot();
});

it("Sampler", async () => {
// const mockTensor = new Tensor("float32", [1, 2, 10], Float32Array.from({ length: 20 }, (_, i) => i));
const mockTensor = new Tensor("float32", [1, 2, 10], Float32Array.from({ length: 20 }, (_, i) => i));
const sampler = new Sampler(10);

// sample
const sample = sampler.sample([1,2,3,4,5,6,7,8,9,10]);
const sample = sampler.sample(mockTensor);
expect(sample).toBeGreaterThanOrEqual(0);
expect(sample).toBeLessThanOrEqual(10);
});
Loading

0 comments on commit 7c6a7ad

Please sign in to comment.