Skip to content

Commit

Permalink
WIP - cuda hanging sometimes...
Browse files Browse the repository at this point in the history
  • Loading branch information
enricozb committed May 30, 2024
1 parent 7f178a6 commit dea5d05
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/hvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ Port io_close_file(Net* net, Book* book, Port argm) {
// file descriptor and string to write.
Port io_write(Net* net, Book* book, Port argm) {
if (get_tag(peek(net, argm)) != CON) {
fprintf(stderr, "io_write: expected tuple, but got %u, port: %u\n", get_tag(peek(net, argm)), argm);
fprintf(stderr, "io_write: expected tuple, but got %u\n", get_tag(peek(net, argm)));
return new_port(ERA, 0);
}

Expand Down
166 changes: 102 additions & 64 deletions src/hvm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Integers
// --------
Expand Down Expand Up @@ -1057,12 +1058,23 @@ __device__ void link_pair(Net* net, TM* tm, Pair AB) {

// Gets the necessary resources for an interaction.
__device__ bool get_resources(Net* net, TM* tm, u8 need_rbag, u8 need_node, u8 need_vars) {
// printf("get resources %u %u %u", need_rbag, need_node, need_vars);
// printf("tm: %p", tm);
// printf("", tm->page);

u32 got_rbag = min(RLEN - tm->rbag.lo_end, RLEN - tm->rbag.hi_end);
u32 got_node;
u32 got_vars;
if (tm->mode != WORK) {
// printf("get resources (not work): %u %u %u", need_rbag, need_node, need_vars);

got_node = g_node_alloc(net, tm, need_node);

// printf("got node");

got_vars = g_vars_alloc(net, tm, need_vars);

// printf("got vars");
} else {
got_node = l_node_alloc(net, tm, need_node);
got_vars = l_vars_alloc(net, tm, need_vars);
Expand Down Expand Up @@ -1444,6 +1456,65 @@ __global__ void boot_redex(GNet* gnet, Pair redex) {
}
}

/// Returns a λ-Encoded Ctr for a NIL: λt (t NIL)
__device__ Port nil_port(Net* net, TM* tm) {
if (!get_resources(net, tm, 0, 2, 1)) {
printf("nil_port: failed to get resources\n");
return new_port(ERA, 0);
}

vars_create(net, tm->vloc[0], NONE);
Port var = new_port(VAR, tm->vloc[0]);

node_create(net, tm->nloc[0], new_pair(new_port(NUM, new_u24(LIST_NIL)), var));
node_create(net, tm->nloc[1], new_pair(new_port(CON, tm->nloc[0]), var));

return new_port(CON, tm->nloc[1]);
}

/// Returns a λ-Encoded Ctr for a CONS: λt (((t CONS) head) tail)
__device__ Port cons_port(Net* net, TM* tm, Port head, Port tail) {
if (!get_resources(net, tm, 0, 4, 1)) {
return new_port(ERA, 0);
}

vars_create(net, tm->vloc[0], NONE);
Port var = new_port(VAR, tm->vloc[0]);

node_create(net, tm->nloc[0], new_pair(tail, var));
node_create(net, tm->nloc[1], new_pair(head, new_port(CON, tm->nloc[0])));
node_create(net, tm->nloc[2], new_pair(new_port(NUM, new_u24(LIST_CONS)), new_port(CON, tm->nloc[1])));
node_create(net, tm->nloc[3], new_pair(new_port(CON, tm->nloc[2]), var));

return new_port(CON, tm->nloc[3]);
}

// Converts a UTF-32 (truncated to 24 bits) string to a Port.
// Since unicode scalars can fit in 21 bits, HVM's u24
// integers can contain any unicode scalar value.
// Encoding:
// - λt (t NIL)
// - λt (((t CONS) head) tail)
__device__ Port str_to_port(Net* net, TM* tm, Str *str) {
Port port = nil_port(net, tm);

u32 len = str->text_len;
for (u32 i = 0; i < len; i++) {
Port chr = new_port(NUM, new_u24(str->text_buf[len - i - 1]));
port = cons_port(net, tm, chr, port);
}

return port;
}

__global__ void make_str_port(GNet* gnet, Str *str, Port* ret) {
if (GID() == 0) {
TM tm;
Net net = vnet_new(gnet, NULL, gnet->turn);
*ret = str_to_port(&net, &tm, str);
}
}

// Creates a node.
__global__ void make_node(GNet* gnet, Tag tag, Port fst, Port snd, Port* ret) {
if (GID() == 0) {
Expand Down Expand Up @@ -1839,7 +1910,7 @@ Port gnet_make_node(GNet* gnet, Tag tag, Port fst, Port snd) {

// Reads back a λ-Encoded constructor from device to host.
// Encoding: λt ((((t TAG) arg0) arg1) ...)
Ctr port_to_ctr(GNet* gnet, Port port) {
Ctr gnet_port_to_ctr(GNet* gnet, Port port) {
Ctr ctr;
ctr.tag = -1;
ctr.args_len = 0;
Expand Down Expand Up @@ -1877,7 +1948,7 @@ Ctr port_to_ctr(GNet* gnet, Port port) {
// Encoding:
// - λt (t NIL)
// - λt (((t CONS) head) tail)
Str port_to_str(GNet* gnet, Port port) {
Str gnet_port_to_str(GNet* gnet, Port port) {
// Result
Str str;
str.text_len = 0;
Expand All @@ -1888,7 +1959,7 @@ Str port_to_str(GNet* gnet, Port port) {
gnet_normalize(gnet);

// Reads the λ-Encoded Ctr
Ctr ctr = port_to_ctr(gnet, gnet_peek(gnet, port));
Ctr ctr = gnet_port_to_ctr(gnet, gnet_peek(gnet, port));

// Reads string layer
switch (ctr.tag) {
Expand All @@ -1914,62 +1985,27 @@ Str port_to_str(GNet* gnet, Port port) {
return str;
}

/// Returns a λ-Encoded Ctr for a NIL: λt (t NIL)
Port nil_port(GNet* gnet) {
TM tm;
Net net = vnet_new(gnet, NULL, gnet->turn);

if (!get_resources(net, tm[0], 0, 2, 1)) {
fprintf(stderr, "nil_port: failed to get resources\n");
return new_port(ERA, 0);
}

vars_create(net, tm[0]->vloc[0], NONE);
Port var = new_port(VAR, tm[0]->vloc[0]);

node_create(net, tm[0]->nloc[0], new_pair(new_port(NUM, new_u24(LIST_NIL)), var));
node_create(net, tm[0]->nloc[1], new_pair(new_port(CON, tm[0]->nloc[0]), var));

return new_port(CON, tm[0]->nloc[1]);
}

/// Returns a λ-Encoded Ctr for a CONS: λt (((t CONS) head) tail)
Port cons_port(Net* net, Port head, Port tail) {
TM tm;
Net net = vnet_new(gnet, NULL, gnet->turn);

if (!get_resources(net, tm[0], 0, 4, 1)) {
fprintf(stderr, "cons_port: failed to get resources\n");
return new_port(ERA, 0);
}

vars_create(net, tm[0]->vloc[0], NONE);
Port var = new_port(VAR, tm[0]->vloc[0]);

node_create(net, tm[0]->nloc[0], new_pair(tail, var));
node_create(net, tm[0]->nloc[1], new_pair(head, new_port(CON, tm[0]->nloc[0])));
node_create(net, tm[0]->nloc[2], new_pair(new_port(NUM, new_u24(LIST_CONS)), new_port(CON, tm[0]->nloc[1])));
node_create(net, tm[0]->nloc[3], new_pair(new_port(CON, tm[0]->nloc[2]), var));

return new_port(CON, tm[0]->nloc[3]);
}

// Converts a UTF-32 (truncated to 24 bits) string to a Port.
// Since unicode scalars can fit in 21 bits, HVM's u24
// integers can contain any unicode scalar value.
// Encoding:
// - λt (t NIL)
// - λt (((t CONS) head) tail)
Port str_to_port(Net* net, Str *str) {
Port port = nil_port(net);
Port gnet_make_str(GNet* gnet, Str *str) {
Port* d_ret;
cudaMalloc(&d_ret, sizeof(Port));

u32 len = str->text_len;
for (u32 i = 0; i < len; i++) {
Port chr = new_port(NUM, new_u24(str->text_buf[len - i - 1]));
port = cons_port(net, chr, port);
}
Str* cu_str;
cudaMalloc(&cu_str, sizeof(Str));
cudaMemcpy(cu_str, str, sizeof(Str), cudaMemcpyHostToDevice);

return port;
make_str_port<<<1,1>>>(gnet, cu_str, d_ret);

Port ret;
cudaMemcpy(&ret, d_ret, sizeof(Port), cudaMemcpyDeviceToHost);
cudaFree(d_ret);

return ret;
}

// Primitive IO Fns
Expand Down Expand Up @@ -2006,7 +2042,7 @@ FILE* port_to_file(Port port) {
}

// Reads a single char from `argm`.
Port io_read_char(GNet* gnet, Book* book, Port argm) {
Port io_read_char(GNet* gnet, Port argm) {
FILE* fp = port_to_file(gnet_peek(gnet, argm));
if (fp == NULL) {
return new_port(ERA, 0);
Expand All @@ -2019,11 +2055,11 @@ Port io_read_char(GNet* gnet, Book* book, Port argm) {
str.text_buf[1] = 0;
str.text_len = 1;

return str_to_port(gnet, &str);
return gnet_make_str(gnet, &str);
}

// Reads from `argm` at most 255 characters or until a newline is seen.
Port io_read_line(GNet* gnet, Book* book, Port argm) {
Port io_read_line(GNet* gnet, Port argm) {
FILE* fp = port_to_file(gnet_peek(gnet, argm));
if (fp == NULL) {
fprintf(stderr, "io_read_line: invalid file descriptor\n");
Expand All @@ -2044,22 +2080,24 @@ Port io_read_line(GNet* gnet, Book* book, Port argm) {
str.text_len--;
}

printf("read this string: '%s'\n", str.text_buf);

// Convert it to a port.
return str_to_port(gnet, &str);
return gnet_make_str(gnet, &str);
}

// Opens a file with the provided mode.
// `argm` is a tuple (CON node) of the
// file name and mode as strings.
Port io_open_file(GNet* gnet, Book* book, Port argm) {
Port io_open_file(GNet* gnet, Port argm) {
if (get_tag(gnet_peek(gnet, argm)) != CON) {
fprintf(stderr, "io_open_file: expected tuple\n");
return new_port(ERA, 0);
}

Pair args = gnet_node_load(gnet, get_val(argm));
Str name = port_to_str(gnet, book, get_fst(args));
Str mode = port_to_str(gnet, book, get_snd(args));
Str name = gnet_port_to_str(gnet, get_fst(args));
Str mode = gnet_port_to_str(gnet, get_snd(args));

for (u32 fd = 3; fd < sizeof(FILE_POINTERS); fd++) {
if (FILE_POINTERS[fd] == NULL) {
Expand All @@ -2074,7 +2112,7 @@ Port io_open_file(GNet* gnet, Book* book, Port argm) {
}

// Closes a file, reclaiming the file descriptor.
Port io_close_file(GNet* gnet, Book* book, Port argm) {
Port io_close_file(GNet* gnet, Port argm) {
FILE* fp = port_to_file(gnet_peek(gnet, argm));
if (fp == NULL) {
fprintf(stderr, "io_close_file: failed to close\n");
Expand All @@ -2095,15 +2133,15 @@ Port io_close_file(GNet* gnet, Book* book, Port argm) {
// Writes a string to a file.
// `argm` is a tuple (CON node) of the
// file descriptor and string to write.
Port io_write(GNet* gnet, Book* book, Port argm) {
Port io_write(GNet* gnet, Port argm) {
if (get_tag(gnet_peek(gnet, argm)) != CON) {
fprintf(stderr, "io_write: expected tuple, but got %u, port: %u\n", get_tag(peek(net, argm)), argm);
fprintf(stderr, "io_write: expected tuple, but got %u", get_tag(gnet_peek(gnet, argm)));
return new_port(ERA, 0);
}

Pair args = gnet_node_load(gnet, get_val(argm));
FILE* fp = port_to_file(gnet_peek(gnet, get_fst(args)));
Str str = port_to_str(gnet, book, get_snd(args));
Str str = gnet_port_to_str(gnet, get_snd(args));

if (fp == NULL) {
fprintf(stderr, "io_write: invalid file descriptor\n");
Expand Down Expand Up @@ -2160,7 +2198,7 @@ void do_run_io(GNet* gnet, Book* book, Port port) {
gnet_normalize(gnet);

// Reads the λ-Encoded Ctr
Ctr ctr = port_to_ctr(gnet, gnet_peek(gnet, port));
Ctr ctr = gnet_port_to_ctr(gnet, gnet_peek(gnet, port));

// Checks if IO Magic Number is a CON
if (get_tag(ctr.args_buf[0]) != CON) {
Expand All @@ -2176,7 +2214,7 @@ void do_run_io(GNet* gnet, Book* book, Port port) {

switch (ctr.tag) {
case IO_CALL: {
Str func = port_to_str(gnet, ctr.args_buf[1]);
Str func = gnet_port_to_str(gnet, ctr.args_buf[1]);
FFn* ffn = NULL;
// FIXME: optimize this linear search
for (u32 fid = 0; fid < book->ffns_len; ++fid) {
Expand Down

0 comments on commit dea5d05

Please sign in to comment.