Skip to content

Commit

Permalink
Use mbedtls to support Socket TLS secureTransport: 'on' (#58)
Browse files Browse the repository at this point in the history
Closes #29.
  • Loading branch information
TooTallNate authored Nov 9, 2023
1 parent 98680d5 commit 047597d
Show file tree
Hide file tree
Showing 12 changed files with 505 additions and 58 deletions.
5 changes: 5 additions & 0 deletions .changeset/forty-scissors-jump.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'nxjs-runtime': patch
---

Use mbedtls to support Socket TLS `secureTransport: 'on'`
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ CXXFLAGS := $(CFLAGS) -fno-rtti -fno-exceptions
ASFLAGS := -g $(ARCH)
LDFLAGS = -specs=$(DEVKITPRO)/libnx/switch.specs -g $(ARCH) -Wl,-Map,$(notdir $*.map)

LIBS := -pthread `freetype-config --libs` `aarch64-none-elf-pkg-config cairo --libs` -lturbojpeg -lwebp -lquickjs -lm3 -lm
LIBS := -pthread -lmbedtls -lmbedx509 -lmbedcrypto `freetype-config --libs` `aarch64-none-elf-pkg-config cairo --libs` -lturbojpeg -lwebp -lquickjs -lm3 -lm

#---------------------------------------------------------------------------------
# list of directories containing libraries, this must be the top level containing
Expand Down
19 changes: 18 additions & 1 deletion packages/runtime/src/$.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { NetworkInfo } from './types';
import type { Callback } from './internal';
import type { Server } from './tcp';
import type { Server, TlsContextOpaque } from './tcp';
import type { MemoryDescriptor, Memory } from './wasm';
import type { VirtualKeyboard } from './navigator/virtual-keyboard';

Expand Down Expand Up @@ -52,6 +52,23 @@ export interface Init {
onAccept: (fd: number) => void
): Server;

// tls.c
tlsHandshake(
cb: Callback<TlsContextOpaque>,
fd: number,
hostname: string
): void;
tlsWrite(
cb: Callback<number>,
ctx: TlsContextOpaque,
data: ArrayBuffer
): void;
tlsRead(
cb: Callback<number>,
ctx: TlsContextOpaque,
buffer: ArrayBuffer
): void;

// wasm.c
wasmCallFunc(f: any, ...args: unknown[]): unknown;
wasmMemNew(descriptor: MemoryDescriptor): Memory;
Expand Down
2 changes: 2 additions & 0 deletions packages/runtime/src/internal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type { SocketOptions } from './types';

export const INTERNAL_SYMBOL = Symbol('Internal');

export type Opaque<T> = { __type: T };

export type Callback<T> = (err: Error | null, result: T) => void;

export type CallbackReturnType<T> = T extends (
Expand Down
3 changes: 1 addition & 2 deletions packages/runtime/src/switch.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { $ } from './$';
import { Canvas, CanvasRenderingContext2D, ctxInternal } from './canvas';
import { FontFaceSet } from './polyfills/font';
import { type Callback, INTERNAL_SYMBOL } from './internal';
import { type Callback, INTERNAL_SYMBOL, type Opaque } from './internal';
import { inspect } from './inspect';
import { bufferSourceToArrayBuffer, toPromise } from './utils';
import { setTimeout, clearTimeout } from './timers';
Expand All @@ -16,7 +16,6 @@ import type {
SocketOptions,
} from './types';

export type Opaque<T> = { __type: T };
export type CanvasRenderingContext2DState =
Opaque<'CanvasRenderingContext2DState'>;
export type FontFaceState = Opaque<'FontFaceState'>;
Expand Down
96 changes: 55 additions & 41 deletions packages/runtime/src/tcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,14 @@ import {
def,
toPromise,
} from './utils';
import type {
BufferSource,
SecureTransportKind,
SocketAddress,
SocketInfo,
} from './types';
import { INTERNAL_SYMBOL, type SocketOptionsInternal } from './internal';

interface SocketInternal {
fd: number;
opened: Deferred<SocketInfo>;
closed: Deferred<void>;
secureTransport: SecureTransportKind;
allowHalfOpen: boolean;
}
import type { BufferSource, SocketAddress, SocketInfo } from './types';
import {
INTERNAL_SYMBOL,
Opaque,
type SocketOptionsInternal,
} from './internal';

const socketInternal = new WeakMap<Socket, SocketInternal>();
export type TlsContextOpaque = Opaque<'TlsContext'>;

export function parseAddress(address: string): SocketAddress {
const firstColon = address.indexOf(':');
Expand All @@ -35,24 +26,6 @@ export function parseAddress(address: string): SocketAddress {
};
}

export function read(fd: number, buffer: BufferSource) {
const ab = bufferSourceToArrayBuffer(buffer);
return toPromise($.read, fd, ab);
}

export function write(fd: number, data: string | BufferSource) {
const d = typeof data === 'string' ? encoder.encode(data) : data;
const ab = bufferSourceToArrayBuffer(d);
return toPromise($.write, fd, ab);
}

/**
* Creates a TCP connection specified by the `hostname`
* and `port`.
*
* @param opts Object containing the `port` number and `hostname` (defaults to `127.0.0.1`) to connect to.
* @returns Promise that is fulfilled once the connection has been successfully established.
*/
export async function connect(opts: SocketAddress) {
const { hostname = '127.0.0.1', port } = opts;
const [ip] = await resolve(hostname);
Expand All @@ -62,6 +35,39 @@ export async function connect(opts: SocketAddress) {
return toPromise($.connect, ip, port);
}

function read(fd: number, buffer: BufferSource) {
const ab = bufferSourceToArrayBuffer(buffer);
return toPromise($.read, fd, ab);
}

function write(fd: number, data: BufferSource) {
const ab = bufferSourceToArrayBuffer(data);
return toPromise($.write, fd, ab);
}

function tlsHandshake(fd: number, hostname: string) {
return toPromise($.tlsHandshake, fd, hostname);
}

function tlsRead(ctx: TlsContextOpaque, buffer: BufferSource) {
const ab = bufferSourceToArrayBuffer(buffer);
return toPromise($.tlsRead, ctx, ab);
}

function tlsWrite(ctx: TlsContextOpaque, data: BufferSource) {
const ab = bufferSourceToArrayBuffer(data);
return toPromise($.tlsWrite, ctx, ab);
}

interface SocketInternal {
fd: number;
tls?: TlsContextOpaque;
opened: Deferred<SocketInfo>;
closed: Deferred<void>;
}

const socketInternal = new WeakMap<Socket, SocketInternal>();

/**
* The `Socket` class represents a TCP connection, from which you can
* read and write data. A socket begins in a _connected_ state (if the
Expand All @@ -72,7 +78,6 @@ export async function connect(opts: SocketAddress) {
export class Socket {
readonly readable: ReadableStream<Uint8Array>;
readonly writable: WritableStream<Uint8Array>;

readonly opened: Promise<SocketInfo>;
readonly closed: Promise<void>;

Expand All @@ -92,8 +97,6 @@ export class Socket {
fd: -1,
opened: new Deferred(),
closed: new Deferred(),
secureTransport,
allowHalfOpen,
};
socketInternal.set(this, i);
this.opened = i.opened.promise;
Expand All @@ -105,29 +108,40 @@ export class Socket {
if (i.opened.pending) {
await socket.opened;
}
const bytesRead = await read(i.fd, readBuffer);
const bytesRead = await (i.tls
? tlsRead(i.tls, readBuffer)
: read(i.fd, readBuffer));
//console.log('read %d bytes', bytesRead);
if (bytesRead === 0) {
controller.close();
if (!allowHalfOpen) {
socket.close();
}
return;
}
controller.enqueue(new Uint8Array(readBuffer, 0, bytesRead));
//controller.enqueue(new Uint8Array(readBuffer, 0, bytesRead));
controller.enqueue(new Uint8Array(readBuffer.slice(0, bytesRead)));
},
});
this.writable = new WritableStream({
async write(chunk, controller) {
async write(chunk) {
if (i.opened.pending) {
await socket.opened;
}
await write(i.fd, chunk);
const n = await (i.tls ? tlsWrite(i.tls, chunk) : write(i.fd, chunk));
//console.log('Wrote %d bytes', n);
},
});

connect(address)
.then((fd) => {
i.fd = fd;
if (secureTransport === 'on') {
return tlsHandshake(fd, address.hostname);
}
})
.then((tls) => {
i.tls = tls;
i.opened.resolve({
localAddress: '',
remoteAddress: '',
Expand Down
12 changes: 8 additions & 4 deletions packages/runtime/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ export function toPromise<
Func extends (cb: Callback<any>, ...args: any[]) => any
>(fn: Func, ...args: CallbackArguments<Func>) {
return new Promise<CallbackReturnType<Func>>((resolve, reject) => {
fn((err, result) => {
if (err) return reject(err);
resolve(result);
}, ...args);
try {
fn((err, result) => {
if (err) return reject(err);
resolve(result);
}, ...args);
} catch (err) {
reject(err);
}
});
}

Expand Down
2 changes: 2 additions & 0 deletions source/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "wasm.h"
#include "image.h"
#include "tcp.h"
#include "tls.h"
#include "poll.h"

#define LOG_FILENAME "nxjs-debug.log"
Expand Down Expand Up @@ -478,6 +479,7 @@ int main(int argc, char *argv[])
nx_init_dns(ctx, init_obj);
nx_init_nifm(ctx, init_obj);
nx_init_tcp(ctx, init_obj);
nx_init_tls(ctx, init_obj);
nx_init_swkbd(ctx, init_obj);
nx_init_wasm(ctx, init_obj);
JS_SetPropertyStr(ctx, global_obj, "$", init_obj);
Expand Down
9 changes: 1 addition & 8 deletions source/tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,6 @@
#include "poll.h"
#include "error.h"

typedef struct
{
JSContext *context;
JSValue callback;
JSValue buffer;
} nx_js_callback_t;

void nx_on_connect(nx_poll_t *p, nx_connect_t *req)
{
nx_js_callback_t *req_cb = (nx_js_callback_t *)req->opaque;
Expand Down Expand Up @@ -42,8 +35,8 @@ void nx_on_connect(nx_poll_t *p, nx_connect_t *req)

JSValue nx_js_tcp_connect(JSContext *ctx, JSValueConst this_val, int argc, JSValueConst *argv)
{
const char *ip = JS_ToCString(ctx, argv[1]);
int port;
const char *ip = JS_ToCString(ctx, argv[1]);
if (!ip || JS_ToInt32(ctx, &port, argv[2]))
{
JS_ThrowTypeError(ctx, "invalid input");
Expand Down
Loading

0 comments on commit 047597d

Please sign in to comment.