Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Compatibility with @lala/appraisal #52

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/mnist/predict.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ let correct = 0;
for (const test of testSet) {
const prediction = argmax(
await network.predict(
tensor(test.inputs.data, [1, ...test.inputs.shape] as Shape[keyof Shape]),
tensor(test.inputs.data, [1, ...test.inputs.shape] as Shape<Rank>),
),
);
const expected = argmax(test.outputs as Tensor<Rank>);
Expand Down
14 changes: 7 additions & 7 deletions src/backends/cpu/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import {
*/
export class CPUBackend implements Backend {
library: Library;
outputShape: Shape[Rank];
outputShape: Shape<Rank>;
#id: bigint;

constructor(
library: Library,
outputShape: Shape[Rank],
outputShape: Shape<Rank>,
id: bigint,
) {
this.library = library;
Expand All @@ -39,7 +39,7 @@ export class CPUBackend implements Backend {
) as bigint;
const outputShape = Array.from(
new Uint32Array(shape.buffer.slice(4).buffer),
) as Shape[Rank];
) as Shape<Rank>;
return new CPUBackend(library, outputShape, id);
}

Expand Down Expand Up @@ -72,13 +72,13 @@ export class CPUBackend implements Backend {
async predict(
input: Tensor<Rank>,
layers: number[],
outputShape: Shape[keyof Shape],
outputShape: Shape<Rank>,
): Promise<Tensor<Rank>>;
//deno-lint-ignore require-await
async predict(
input: Tensor<Rank>,
layers?: number[],
outputShape?: Shape[keyof Shape],
outputShape?: Shape<Rank>,
): Promise<Tensor<Rank>> {
const options = encodeJSON({
inputShape: input.shape,
Expand All @@ -100,7 +100,7 @@ export class CPUBackend implements Backend {
[
input.shape[0],
...(outputShape ?? this.outputShape),
] as Shape[keyof Shape],
] as Shape<Rank>,
);
}

Expand All @@ -121,7 +121,7 @@ export class CPUBackend implements Backend {
buffer.length,
shape.allocBuffer,
) as bigint;
const outputShape = Array.from(shape.buffer.slice(1)) as Shape[Rank];
const outputShape = Array.from(shape.buffer.slice(1)) as Shape<Rank>;

return new CPUBackend(library, outputShape, id);
}
Expand Down
8 changes: 4 additions & 4 deletions src/backends/cpu/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ export class Buffer {
*/
export type TrainOptions = {
datasets: number;
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
epochs: number;
batches: number;
rate: number;
Expand All @@ -31,8 +31,8 @@ export type TrainOptions = {
* Predict Options Interface.
*/
export type PredictOptions = {
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
};

/**
Expand Down
14 changes: 7 additions & 7 deletions src/backends/gpu/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import {
*/
export class GPUBackend implements Backend {
library: Library;
outputShape: Shape[Rank];
outputShape: Shape<Rank>;
#id: bigint;

constructor(library: Library, outputShape: Shape[Rank], id: bigint) {
constructor(library: Library, outputShape: Shape<Rank>, id: bigint) {
this.library = library;
this.outputShape = outputShape;
this.#id = id;
Expand All @@ -33,7 +33,7 @@ export class GPUBackend implements Backend {
buffer.length,
shape.allocBuffer,
) as bigint;
const outputShape = Array.from(shape.buffer.slice(1)) as Shape[Rank];
const outputShape = Array.from(shape.buffer.slice(1)) as Shape<Rank>;

return new GPUBackend(library, outputShape, id);
}
Expand Down Expand Up @@ -67,13 +67,13 @@ export class GPUBackend implements Backend {
async predict(
input: Tensor<Rank>,
layers: number[],
outputShape: Shape[keyof Shape],
outputShape: Shape<Rank>,
): Promise<Tensor<Rank>>;
//deno-lint-ignore require-await
async predict(
input: Tensor<Rank>,
layers?: number[],
outputShape?: Shape[keyof Shape],
outputShape?: Shape<Rank>,
): Promise<Tensor<Rank>> {
const options = encodeJSON({
inputShape: input.shape,
Expand All @@ -95,7 +95,7 @@ export class GPUBackend implements Backend {
[
input.shape[0],
...(outputShape ?? this.outputShape),
] as Shape[keyof Shape],
] as Shape<Rank>,
);
}

Expand All @@ -116,7 +116,7 @@ export class GPUBackend implements Backend {
buffer.length,
shape.allocBuffer,
) as bigint;
const outputShape = Array.from(shape.buffer.slice(1)) as Shape[Rank];
const outputShape = Array.from(shape.buffer.slice(1)) as Shape<Rank>;

return new GPUBackend(library, outputShape, id);
}
Expand Down
8 changes: 4 additions & 4 deletions src/backends/gpu/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export class Buffer {
*/
export type TrainOptions = {
datasets: number;
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
epochs: number;
batches: number;
rate: number;
Expand All @@ -28,8 +28,8 @@ export type TrainOptions = {
* Predict Options Interface.
*/
export type PredictOptions = {
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
};

/**
Expand Down
10 changes: 5 additions & 5 deletions src/backends/wasm/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ import { PredictOptions, TrainOptions } from "./utils.ts";
* Web Assembly Backend.
*/
export class WASMBackend implements Backend {
outputShape: Shape[Rank];
outputShape: Shape<Rank>;
#id: number;

constructor(outputShape: Shape[Rank], id: number) {
constructor(outputShape: Shape<Rank>, id: number) {
this.outputShape = outputShape;
this.#id = id;
}

static create(config: NetworkConfig): WASMBackend {
const shape = Array(0);
const id = wasm_backend_create(JSON.stringify(config), shape);
return new WASMBackend(shape as Shape[Rank], id);
return new WASMBackend(shape as Shape<Rank>, id);
}

train(datasets: DataSet[], epochs: number, batches: number, rate: number): void {
this.outputShape = datasets[0].outputs.shape.slice(1) as Shape[Rank];
this.outputShape = datasets[0].outputs.shape.slice(1) as Shape<Rank>;
const buffer = [];
for (const dataset of datasets) {
buffer.push(dataset.inputs.data as Float32Array);
Expand Down Expand Up @@ -76,6 +76,6 @@ export class WASMBackend implements Backend {
static load(input: Uint8Array): WASMBackend {
const shape = Array(0);
const id = wasm_backend_load(input, shape);
return new WASMBackend(shape as Shape[Rank], id);
return new WASMBackend(shape as Shape<Rank>, id);
}
}
8 changes: 4 additions & 4 deletions src/backends/wasm/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import { Rank, Shape } from "../../core/api/shape.ts";
*/
export type TrainOptions = {
datasets: number;
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
epochs: number;
batches: number;
rate: number;
Expand All @@ -16,6 +16,6 @@ export type TrainOptions = {
* Predict Options Interface.
*/
export type PredictOptions = {
inputShape: Shape[Rank];
outputShape: Shape[Rank];
inputShape: Shape<Rank>;
outputShape: Shape<Rank>;
};
4 changes: 2 additions & 2 deletions src/core/api/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export class IncompatibleRankError extends Error {
* Invalid Flatten Error is thrown when a tensor cannot be flattened.
*/
export class InvalidFlattenError extends Error {
constructor(input: Shape[Rank], output: Shape[Rank]) {
constructor(input: Shape<Rank>, output: Shape<Rank>) {
super(`Cannot flatten tensor of shape ${input} to shape ${output}`);
}
}
Expand All @@ -36,7 +36,7 @@ export class NoBackendError extends Error {
* Invalid Pool Error is thrown when a tensor cannot be pooled.
*/
export class InvalidPoolError extends Error {
constructor(size: Shape[Rank], stride: Shape2D) {
constructor(size: Shape<Rank>, stride: Shape2D) {
super(`Cannot pool shape ${size} with stride ${stride}`);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/api/layer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ export type FlattenLayerConfig = {
/**
* The size of the layer.
*/
size: Shape[Rank];
size: Shape<Rank>;
};

/**
Expand Down
28 changes: 10 additions & 18 deletions src/core/api/shape.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
/**
* Shape Type
*/
export type Shape<R extends Rank> = [number, ...number[]] & { length: R };
/**
* 1st dimentional shape.
*/
export type Shape1D = [number];
export type Shape1D = Shape<1>;

/**
* 2nd dimentional shape.
*/
export type Shape2D = [number, number];
export type Shape2D = Shape<2>;

/**
* 3th dimentional shape.
*/
export type Shape3D = [number, number, number];
export type Shape3D = Shape<3>;

/**
* 4th dimentional shape.
*/
export type Shape4D = [number, number, number, number];
export type Shape4D = Shape<4>;

/**
* 5th dimentional shape.
*/
export type Shape5D = [number, number, number, number, number];
export type Shape5D = Shape<5>;

/**
* 6th dimentional shape.
*/
export type Shape6D = [number, number, number, number, number, number];
export type Shape6D = Shape<6>;

/**
* Rank Types.
Expand Down Expand Up @@ -63,18 +67,6 @@ export enum Rank {
R6 = 6,
}

/**
* Shape Interface
*/
export interface Shape {
1: Shape1D;
2: Shape2D;
3: Shape3D;
4: Shape4D;
5: Shape5D;
6: Shape6D;
}

/**
* Array Map Types.
*/
Expand Down
37 changes: 30 additions & 7 deletions src/core/tensor/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,36 @@ import {
} from "../api/shape.ts";
import { inferShape, length } from "./util.ts";

export type TensorLike<R extends Rank> = {
shape: Shape<R>;
data: Float32Array;
};

/**
* A generic N-dimensional tensor.
*/
export class Tensor<R extends Rank> {
shape: Shape[R];
shape: Shape<R>;
data: Float32Array;

constructor(data: Float32Array, shape: Shape[R]) {
constructor(data: Float32Array, shape: Shape<R>) {
this.shape = shape;
this.data = data;
}

/**
* Creates an empty tensor.
*/
static zeroes<R extends Rank>(shape: Shape[R]): Tensor<R> {
static zeroes<R extends Rank>(shape: Shape<R>): Tensor<R> {
return new Tensor(new Float32Array(length(shape)), shape);
}

/**
* Serialise a tensor into JSON.
*/
toJSON(): { data: number[]; shape: Shape[R] } {
toJSON(): { data: number[]; shape: Shape<R> } {
const data = new Array(this.data.length).fill(1);
this.data.forEach((value, i) => data[i] = value);
this.data.forEach((value, i) => (data[i] = value));
return { data, shape: this.shape };
}
}
Expand All @@ -51,11 +56,29 @@ export class Tensor<R extends Rank> {
* tensor([1, 2, 3, 4], [2, 2]);
* ```
*/
export function tensor<R extends Rank>(tensorLike: TensorLike<R>): Tensor<R>;
export function tensor<R extends Rank>(
values: Float32Array,
shape: Shape[R],
shape: Shape<R>
): Tensor<R>;
export function tensor<R extends Rank>(
values: Float32Array | TensorLike<R>,
shape?: Shape<R>
): Tensor<R> {
return new Tensor(values, shape);
if (values instanceof Float32Array) {
if (!shape)
throw new Error("Cannot initialize Tensor without a shape parameter.");
return new Tensor(values, shape);
}
if (!values.data || !values.shape)
throw new Error(
`Cannot initialize Tensor: Expected keys 'data', 'shape'. Got ${Object.keys(
values
)
.map((x) => `'${x}'`)
.join(", ")}.`
);
return new Tensor(values.data, values.shape);
}

/**
Expand Down
Loading
Loading