Skip to content

Commit

Permalink
fix: drop support for Sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed Sep 25, 2024
1 parent df7f99b commit 5b04db4
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 97 deletions.
3 changes: 0 additions & 3 deletions lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ interface Binding {
moduleExecute(ptr: ExternalObject, method_name: string, inputs: InternalEValue[]): Promise<InternalEValue[]>;
moduleGetMethodMeta(ptr: ExternalObject, method_name: string): MethodMeta;
moduleMethodNames(ptr: ExternalObject): string[];
// sampler methods
createSampler(vocab_size: number, temperature: number, top_p: number, seed: number): ExternalObject;
samplerSample(ptr: ExternalObject, tensor: ExternalObject): number;
// tensor methods
createTensor(dtype: DType, shape: number[], data: ArrayBuffer): ExternalObject;
tensorGetDtype(ptr: ExternalObject): DType;
Expand Down
31 changes: 1 addition & 30 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,33 +170,4 @@ class Module {
}
}

class Sampler {
_ptr: ExternalObject;
_vocab_size: number;

constructor(vocab_size: number, temperature: number = 0.7, topP: number = 0.9, seed?: number) {
this._vocab_size = vocab_size;
this._ptr = mod.createSampler(
vocab_size,
temperature,
topP,
seed ?? Math.floor(Math.random() * 1000000)
);
}

sample(tensor: Tensor): number {
if (tensor.dtype !== "float32") {
throw new Error(`Unsupported dtype: ${tensor.dtype}`);
}
if (tensor.shape.length !== 3 || tensor.shape[0] !== 1 || tensor.shape[1] === 0 || tensor.shape[2] !== this._vocab_size) {
throw new Error(`Invalid shape: ${tensor.shape}`);
}
return mod.samplerSample(this._ptr, tensor._ptr);
}

dispose() {
delete this._ptr;
}
}

export { Sampler, Tensor, Module };
export { Tensor, Module };
4 changes: 0 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@ mod evalue_tag;
mod macros;
mod method_meta;
mod module;
mod sampler;
mod tensor;
mod tensor_type;

use neon::prelude::*;

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
// Sampler
cx.export_function("createSampler", sampler::create)?;
cx.export_function("samplerSample", sampler::sample)?;
// Tensor
cx.export_function("createTensor", tensor::create)?;
cx.export_function("tensorGetDtype", tensor::get_dtype)?;
Expand Down
60 changes: 0 additions & 60 deletions src/sampler.rs

This file was deleted.

0 comments on commit 5b04db4

Please sign in to comment.