Skip to content

Commit

Permalink
revert: add back sampler support
Browse files Browse the repository at this point in the history
This reverts commit 5b04db4.
  • Loading branch information
hans00 committed Oct 4, 2024
1 parent cd8412b commit 96c2a37
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 2 deletions.
3 changes: 3 additions & 0 deletions lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ interface Binding {
moduleExecute(ptr: ExternalObject, method_name: string, inputs: InternalEValue[]): Promise<InternalEValue[]>;
moduleGetMethodMeta(ptr: ExternalObject, method_name: string): MethodMeta;
moduleMethodNames(ptr: ExternalObject): string[];
// sampler methods
createSampler(vocab_size: number, temperature: number, top_p: number, seed: number): ExternalObject;
samplerSample(ptr: ExternalObject, tensor: ExternalObject): number;
// tensor methods
createTensor(dtype: DType, shape: number[], data: ArrayBuffer): ExternalObject;
tensorGetDtype(ptr: ExternalObject): DType;
Expand Down
12 changes: 11 additions & 1 deletion lib/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import path from "path";
import { Module, Tensor, EValueTag, DType } from "./index";
import { Module, Tensor, Sampler, EValueTag, DType } from "./index";

const model = path.resolve(__dirname, "__fixtures__/mul.pte");

Expand Down Expand Up @@ -53,3 +53,13 @@ it("Tensor", async () => {
expect(reshaped.shape).toEqual([2, 3]);
expect(reshaped.data).toMatchSnapshot();
});

// Sampler: a sampled token id must be a valid index into the vocabulary,
// i.e. an integer in the half-open range [0, vocabSize).
it("Sampler", async () => {
  const vocabSize = 10;
  // Logits shaped [1, 2, vocabSize], filled with 0..19.
  const mockTensor = new Tensor(
    "float32",
    [1, 2, vocabSize],
    Float32Array.from({ length: 2 * vocabSize }, (_, i) => i)
  );
  const sampler = new Sampler(vocabSize);

  // sample
  const sample = sampler.sample(mockTensor);
  expect(Number.isInteger(sample)).toBe(true);
  expect(sample).toBeGreaterThanOrEqual(0);
  // Strict upper bound: `vocabSize` itself is not a valid token id.
  expect(sample).toBeLessThan(vocabSize);
});
31 changes: 30 additions & 1 deletion lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,33 @@ class Module {
}
}

export { Tensor, Module };
/**
 * Thin wrapper around the native sampler exposed by the binding.
 *
 * Holds the native handle in `_ptr` and remembers the vocabulary size so that
 * input tensors can be shape-checked before they are handed to native code.
 */
class Sampler {
  _ptr: ExternalObject;
  _vocab_size: number;

  /**
   * @param vocab_size  Expected size of the last tensor dimension.
   * @param temperature Forwarded to the native sampler (default 0.7).
   * @param topP        Forwarded to the native sampler (default 0.9).
   * @param seed        RNG seed; when omitted a random integer in [0, 1000000) is used.
   */
  constructor(vocab_size: number, temperature: number = 0.7, topP: number = 0.9, seed?: number) {
    const rngSeed = seed ?? Math.floor(Math.random() * 1000000);
    this._vocab_size = vocab_size;
    this._ptr = mod.createSampler(vocab_size, temperature, topP, rngSeed);
  }

  /**
   * Samples one value from a float32 tensor of shape `[1, n, vocab_size]` with `n > 0`.
   *
   * @throws Error if the dtype is not "float32" or the shape does not match.
   */
  sample(tensor: Tensor): number {
    if (tensor.dtype !== "float32") {
      throw new Error(`Unsupported dtype: ${tensor.dtype}`);
    }

    const dims = tensor.shape;
    const shapeOk =
      dims.length === 3 &&
      dims[0] === 1 &&
      dims[1] !== 0 &&
      dims[2] === this._vocab_size;
    if (!shapeOk) {
      throw new Error(`Invalid shape: ${tensor.shape}`);
    }

    return mod.samplerSample(this._ptr, tensor._ptr);
  }

  /** Removes this wrapper's reference to the native handle. */
  dispose() {
    delete this._ptr;
  }
}

export { Sampler, Tensor, Module };
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ mod evalue_tag;
mod macros;
mod method_meta;
mod module;
mod sampler;
mod tensor;
mod tensor_type;

use neon::prelude::*;

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
// Sampler
cx.export_function("createSampler", sampler::create)?;
cx.export_function("samplerSample", sampler::sample)?;
// Tensor
cx.export_function("createTensor", tensor::create)?;
cx.export_function("tensorGetDtype", tensor::get_dtype)?;
Expand Down
60 changes: 60 additions & 0 deletions src/sampler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use crate::arg_get_value;
use crate::tensor::Tensor;
use crate::tensor_type::TensorType;
use cpp::{cpp, cpp_class};
use neon::prelude::*;
use neon::types::Finalize;

// C++ prelude for the `cpp!` closures in this file. Note that the sampler's
// `.cpp` translation unit is included directly, in addition to its header, so
// the implementation is compiled together with this file's C++ snippets.
cpp! {{
#include <executorch/examples/models/llama2/sampler/sampler.h>
#include <executorch/examples/models/llama2/sampler/sampler.cpp>
}}

// Rust-side type `Sampler` wrapping the C++ class `torch::executor::Sampler`.
cpp_class!(pub unsafe struct Sampler as "torch::executor::Sampler");

// Empty `Finalize` impl (default behavior); `create` stores a `Sampler` in a
// `JsBox`, which is why the trait is implemented here.
impl Finalize for Sampler {}

impl Sampler {
    /// Constructs the underlying C++ `torch::executor::Sampler`.
    ///
    /// All four arguments are forwarded unchanged to the C++ constructor.
    pub fn new(vocab_size: i32, temperature: f32, topp: f32, rng_seed: u64) -> Self {
        // SAFETY: only plain scalar values cross the FFI boundary, and the
        // declared C++ types match the Rust types in size.
        unsafe {
            cpp!([vocab_size as "int", temperature as "float", topp as "float", rng_seed as "uint64_t"] -> Sampler as "torch::executor::Sampler" {
                return torch::executor::Sampler(vocab_size, temperature, topp, rng_seed);
            })
        }
    }

    /// Samples one token id from `param` (one row of logits).
    ///
    /// The C++ entry point takes a non-const `float *` and may rewrite the
    /// logits in place (temperature scaling / softmax). Handing it a pointer
    /// derived from a shared `&[f32]` would mutate memory the caller only lent
    /// immutably, so sampling runs on a private copy instead.
    ///
    /// NOTE(review): the C++ sampler presumably reads as many floats as the
    /// `vocab_size` given to [`Sampler::new`]; callers must pass a slice at
    /// least that long — this cannot be verified from the Rust side.
    pub fn sample(&self, param: &[f32]) -> i32 {
        let mut logits = param.to_vec();
        let array: *mut f32 = logits.as_mut_ptr();
        // SAFETY: `array` points to `logits`, an exclusively owned buffer that
        // outlives the call; `self` is a live C++ object owned by this wrapper.
        unsafe {
            cpp!([self as "torch::executor::Sampler*", array as "float *"] -> i32 as "int32_t" {
                return self->sample<float>(array);
            })
        }
    }
}

// JS interface

/// JS binding `createSampler(vocab_size, temperature, top_p, seed)`.
///
/// Reads the four numeric arguments in positional order, builds a [`Sampler`]
/// from them, and returns it to JavaScript wrapped in a `JsBox`.
pub fn create(mut cx: FunctionContext) -> JsResult<JsBox<Sampler>> {
    let vocab = arg_get_value!(cx, 0, JsNumber, i32);
    let temp = arg_get_value!(cx, 1, JsNumber, f32);
    let top_p = arg_get_value!(cx, 2, JsNumber, f32);
    let seed = arg_get_value!(cx, 3, JsNumber, u64);

    let sampler = Sampler::new(vocab, temp, top_p, seed);
    Ok(cx.boxed(sampler))
}

/// JS binding `samplerSample(sampler, tensor)`.
///
/// Expects a `Float32` tensor of shape `[1, positions, n]` and samples from
/// the last row along the second dimension (the final `n` values of the data).
///
/// Throws a JS error, rather than panicking, when:
/// - the tensor is not `Float32`,
/// - the shape is not three-dimensional with a leading `1`,
/// - the second or third dimension is zero,
/// - the tensor data is shorter than the shape implies.
pub fn sample(mut cx: FunctionContext) -> JsResult<JsNumber> {
    let sampler = cx.argument::<JsBox<Sampler>>(0)?;
    let input = cx.argument::<JsBox<Tensor>>(1)?;
    if input.dtype() != TensorType::Float32 {
        return cx.throw_error("Input tensor must be of type Float32");
    }
    let shape = input.sizes();
    if shape.len() != 3 || shape[0] != 1 {
        return cx.throw_error("Input tensor must have shape [1, ?, N]");
    }

    // Widen before doing arithmetic so the offsets cannot overflow the
    // (narrower) dimension type; nonsensical dimensions end up as huge values
    // that the checked arithmetic / bounds check below rejects.
    let positions = shape[1] as usize;
    let vocab = shape[2] as usize;
    if positions == 0 || vocab == 0 {
        // `positions - 1` would underflow, and an empty row cannot be sampled.
        return cx.throw_error("Input tensor must have non-empty dimensions");
    }

    let inputs = input.get_data::<f32>();
    let start = (positions - 1).checked_mul(vocab);
    let end = start.and_then(|s| s.checked_add(vocab));
    let last_row = start.zip(end).and_then(|(s, e)| inputs.get(s..e));

    match last_row {
        Some(logits) => Ok(cx.number(sampler.sample(logits))),
        None => cx.throw_error("Input tensor data is smaller than its shape"),
    }
}

0 comments on commit 96c2a37

Please sign in to comment.