Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add match statement #621

Merged
merged 15 commits into from
Jan 15, 2025
31 changes: 31 additions & 0 deletions bootstrap_compiler/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,33 @@ static void build_loop(
add_jump(st, NULL, condblock, condblock, doneblock);
}

static void build_match_statament(struct State *st, const AstMatchStatement *match_stmt)
{
const LocalVariable *match_obj_enum = build_expression(st, &match_stmt->match_obj);
LocalVariable *match_obj_int = add_local_var(st, intType);
add_unary_op(st, match_stmt->match_obj.location, CF_ENUM_TO_INT32, match_obj_enum, match_obj_int);

CfBlock *done = add_block(st);
for (int i = 0; i < match_stmt->ncases; i++) {
for (AstExpression *caseobj = match_stmt->cases[i].case_objs; caseobj < &match_stmt->cases[i].case_objs[match_stmt->cases[i].n_case_objs]; caseobj++) {
const LocalVariable *case_obj_enum = build_expression(st, caseobj);
LocalVariable *case_obj_int = add_local_var(st, intType);
add_unary_op(st, caseobj->location, CF_ENUM_TO_INT32, case_obj_enum, case_obj_int);

const LocalVariable *cond = build_binop(st, AST_EXPR_EQ, caseobj->location, match_obj_int, case_obj_int, boolType);
CfBlock *then = add_block(st);
CfBlock *otherwise = add_block(st);

add_jump(st, cond, then, otherwise, then);
build_body(st, &match_stmt->cases[i].body);
add_jump(st, NULL, done, done, otherwise);
}
}

build_body(st, &match_stmt->case_underscore);
add_jump(st, NULL, done, done, done);
}

static void build_statement(struct State *st, const AstStatement *stmt)
{
switch(stmt->kind) {
Expand All @@ -891,6 +918,10 @@ static void build_statement(struct State *st, const AstStatement *stmt)
&stmt->data.forloop.body);
break;

case AST_STMT_MATCH:
build_match_statament(st, &stmt->data.match);
break;

case AST_STMT_BREAK:
if (!st->breakstack.len)
fail(stmt->location, "'break' can only be used inside a loop");
Expand Down
12 changes: 12 additions & 0 deletions bootstrap_compiler/free.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ void free_ast_statement(const AstStatement *stmt)
free(stmt->data.forloop.incr);
free_ast_body(&stmt->data.forloop.body);
break;
case AST_STMT_MATCH:
free_expression(&stmt->data.match.match_obj);
for (int i = 0; i < stmt->data.match.ncases; i++) {
for (AstExpression *caseobj = stmt->data.match.cases[i].case_objs; caseobj < &stmt->data.match.cases[i].case_objs[stmt->data.match.cases[i].n_case_objs]; caseobj++) {
free_expression(caseobj);
}
free(stmt->data.match.cases[i].case_objs);
free_ast_body(&stmt->data.match.cases[i].body);
}
free(stmt->data.match.cases);
free_ast_body(&stmt->data.match.case_underscore);
break;
case AST_STMT_ASSERT:
free_expression(&stmt->data.assertion.condition);
free(stmt->data.assertion.condition_str);
Expand Down
15 changes: 15 additions & 0 deletions bootstrap_compiler/jou_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ typedef struct AstConditionAndBody AstConditionAndBody;
typedef struct AstExpression AstExpression;
typedef struct AstAssignment AstAssignment;
typedef struct AstForLoop AstForLoop;
typedef struct AstCase AstCase;
typedef struct AstMatchStatement AstMatchStatement;
typedef struct AstNameTypeValue AstNameTypeValue;
typedef struct AstIfStatement AstIfStatement;
typedef struct AstStatement AstStatement;
Expand Down Expand Up @@ -264,6 +266,17 @@ struct AstForLoop {
AstStatement *incr;
AstBody body;
};
// One non-underscore case of a match statement:
//
//      case case_obj1 | case_obj2 | case_obj3:
//          body
struct AstCase {
    AstExpression *case_objs;   // heap-allocated array of the '|'-separated values
    int n_case_objs;            // number of elements in case_objs
    AstBody body;               // statements to run when one of the values matches
};
// A whole match statement:
//
//      match match_obj:
//          case ...:
//              ...
//          case _:
//              ...
struct AstMatchStatement {
    AstExpression match_obj;    // the value that the cases are compared against
    AstCase *cases;             // heap-allocated array of all cases other than "case _"
    int ncases;                 // number of elements in cases
    AstBody case_underscore;    // body of "case _"; left zero-initialized by the parser if there is no "case _"
};
struct AstIfStatement {
AstConditionAndBody *if_and_elifs;
int n_if_and_elifs; // Always >= 1 for the initial "if"
Expand Down Expand Up @@ -327,6 +340,7 @@ struct AstStatement {
AST_STMT_IF,
AST_STMT_WHILE,
AST_STMT_FOR,
AST_STMT_MATCH,
AST_STMT_BREAK,
AST_STMT_CONTINUE,
AST_STMT_DECLARE_LOCAL_VAR,
Expand All @@ -350,6 +364,7 @@ struct AstStatement {
AstConditionAndBody whileloop;
AstIfStatement ifstatement;
AstForLoop forloop;
AstMatchStatement match;
AstNameTypeValue vardecl;
AstAssignment assignment; // also used for inplace operations
AstFunction function;
Expand Down
47 changes: 47 additions & 0 deletions bootstrap_compiler/parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ static void validate_expression_statement(const AstExpression *expr)
}
}

static void parse_start_of_body(ParserState *ps);
static AstBody parse_body(ParserState *ps);

static AstIfStatement parse_if_statement(ParserState *ps)
Expand Down Expand Up @@ -701,6 +702,49 @@ static AstIfStatement parse_if_statement(ParserState *ps)
};
}

// Parses a match statement:
//
//      match match_obj:
//          case a | b:
//              ...
//          case _:
//              ...
//
// The caller must have checked that the current token is the "match" keyword.
// On return, ps->tokens points just past the DEDENT that ends the statement.
//
// Ownership: result.cases and each case_objs array are heap-allocated and
// belong to the returned AstMatchStatement.
static AstMatchStatement parse_match_statement(ParserState *ps)
{
    assert(is_keyword(ps->tokens, "match"));
    ps->tokens++;

    AstMatchStatement result = {.match_obj = parse_expression(ps)};
    parse_start_of_body(ps);

    while (ps->tokens->type != TOKEN_DEDENT) {
        // This depends on what the user wrote, so it must be a proper parse
        // error. An assert() would abort the compiler, or be compiled out
        // with NDEBUG and let a wrong token be skipped silently.
        if (!is_keyword(ps->tokens, "case"))
            fail_with_parse_error(ps->tokens, "the 'case' keyword");
        ps->tokens++;

        if (ps->tokens->type == TOKEN_NAME
            && strcmp(ps->tokens->data.name, "_") == 0
            && is_operator(&ps->tokens[1], ":"))
        {
            // case _:
            ps->tokens++;  // skip "_", leave ":" for parse_body()
            result.case_underscore = parse_body(ps);
        } else {
            // case obj1 | obj2 | ... :
            List(AstExpression) case_objs = {0};
            while (1) {
                Append(&case_objs, parse_expression(ps));
                if (is_operator(ps->tokens, "|"))
                    ps->tokens++;
                else if (is_operator(ps->tokens, ":"))
                    break;
                else
                    fail_with_parse_error(ps->tokens, "'|' or ':'");
            }

            result.cases = realloc(result.cases, sizeof result.cases[0] * (result.ncases + 1));
            result.cases[result.ncases++] = (AstCase){
                .case_objs = case_objs.ptr,
                .n_case_objs = case_objs.len,
                .body = parse_body(ps),
            };
        }
    }

    ps->tokens++;  // skip the DEDENT that ended the loop
    return result;
}


// reverse code golfing: https://xkcd.com/1960/
static enum AstStatementKind determine_the_kind_of_a_statement_that_starts_with_an_expression(
const Token *this_token_is_after_that_initial_expression)
Expand Down Expand Up @@ -1041,6 +1085,9 @@ static AstStatement parse_statement(ParserState *ps)
} else if (is_keyword(ps->tokens, "if")) {
result.kind = AST_STMT_IF;
result.data.ifstatement = parse_if_statement(ps);
} else if (is_keyword(ps->tokens, "match")) {
result.kind = AST_STMT_MATCH;
result.data.match = parse_match_statement(ps);
} else if (is_keyword(ps->tokens, "while")) {
ps->tokens++;
result.kind = AST_STMT_WHILE;
Expand Down
4 changes: 4 additions & 0 deletions bootstrap_compiler/print.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ static void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp)
printf("body:\n");
print_ast_body(&stmt->data.forloop.body, sub);
break;
case AST_STMT_MATCH:
printf("match (printing not implemented)\n");
// TODO: implement printing match statement, if needed for debugging
break;
case AST_STMT_BREAK:
printf("break\n");
break;
Expand Down
4 changes: 2 additions & 2 deletions bootstrap_compiler/tokenize.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ static bool is_keyword(const char *s)
"return", "if", "elif", "else", "while", "for", "pass", "break", "continue",
"True", "False", "None", "NULL", "void", "noreturn",
"and", "or", "not", "self", "as", "sizeof", "assert",
"bool", "byte", "short", "int", "long", "float", "double",
"bool", "byte", "short", "int", "long", "float", "double", "match", "case",
};
for (const char **kw = &keywords[0]; kw < &keywords[sizeof(keywords)/sizeof(keywords[0])]; kw++)
if (!strcmp(*kw, s))
Expand Down Expand Up @@ -351,7 +351,7 @@ static const char *read_operator(struct State *st)
// Longer operators are first, so that '==' does not tokenize as '=' '='
"...", "===", "!==",
"==", "!=", "->", "<=", ">=", "++", "--", "+=", "-=", "*=", "/=", "%=", "&&", "||",
".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!",
".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", "|",
NULL,
};

Expand Down
22 changes: 21 additions & 1 deletion bootstrap_compiler/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,6 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi

// Fails with an error at the given location unless t is a pointer type.
static void typecheck_dereferenced_pointer(Location location, const Type *t)
{
    // TODO: improved error message for dereferencing void*
    if (t->kind == TYPE_POINTER)
        return;

    fail(location, "the dereference operator '*' is only for pointers, not for %s", t->name);
}
Expand Down Expand Up @@ -1287,6 +1286,23 @@ static void typecheck_if_statement(FileTypes *ft, const AstIfStatement *ifstmt)
typecheck_body(ft, &ifstmt->elsebody);
}

// Type-checks a match statement.
//
// The matched expression must be an enum. Every case value is type-checked
// with an implicit cast to the type of the matched expression, and all case
// bodies (including the body of "case _") are type-checked.
static void typecheck_match_statement(FileTypes *ft, AstMatchStatement *match_stmt)
{
    const Type *mtype = typecheck_expression_not_void(ft, &match_stmt->match_obj)->type;

    // The type comes from user code, so report a compiler error instead of
    // tripping an assert() when it is not an enum.
    if (mtype->kind != TYPE_ENUM) {
        fail(
            match_stmt->match_obj.location,
            "match statements can only be used with enums, not with %s",
            mtype->name);
    }

    for (int i = 0; i < match_stmt->ncases; i++) {
        for (int k = 0; k < match_stmt->cases[i].n_case_objs; k++) {
            typecheck_expression_with_implicit_cast(
                ft, &match_stmt->cases[i].case_objs[k], mtype,
                "case value of type FROM cannot be matched against TO"
            );
        }
        typecheck_body(ft, &match_stmt->cases[i].body);
    }
    typecheck_body(ft, &match_stmt->case_underscore);
}

static void typecheck_statement(FileTypes *ft, AstStatement *stmt)
{
switch(stmt->kind) {
Expand All @@ -1310,6 +1326,10 @@ static void typecheck_statement(FileTypes *ft, AstStatement *stmt)
typecheck_statement(ft, stmt->data.forloop.incr);
break;

case AST_STMT_MATCH:
typecheck_match_statement(ft, &stmt->data.match);
break;

case AST_STMT_BREAK:
break;

Expand Down
62 changes: 62 additions & 0 deletions compiler/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ enum AstStatementKind:
Pass
Return
If
Match
WhileLoop
ForLoop
Break
Expand Down Expand Up @@ -515,6 +516,7 @@ class AstStatement:
classdef: AstClassDef
enumdef: AstEnumDef
assertion: AstAssertion
match_statement: AstMatchStatement

def print(self) -> None:
self->print_with_tree_printer(TreePrinter{})
Expand All @@ -537,6 +539,8 @@ class AstStatement:
self->if_statement.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.ForLoop:
self->for_loop.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.Match:
self->match_statement.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.WhileLoop:
printf("while loop\n")
self->while_loop.print_with_tree_printer(tp, True)
Expand Down Expand Up @@ -601,6 +605,8 @@ class AstStatement:
self->while_loop.free()
if self->kind == AstStatementKind.ForLoop:
self->for_loop.free()
if self->kind == AstStatementKind.Match:
self->match_statement.free()
if (
self->kind == AstStatementKind.DeclareLocalVar
or self->kind == AstStatementKind.GlobalVariableDeclaration
Expand Down Expand Up @@ -696,6 +702,62 @@ class AstIfStatement:
self->else_body.free()


# match match_obj:
# case ...:
# ...
# case ...:
# ...
class AstMatchStatement:
    match_obj: AstExpression
    cases: AstCase*  # array of ncases elements, released with free()
    ncases: int
    case_underscore: AstBody*  # body of "case _" (always last), NULL if no "case _"
    case_underscore_location: Location  # not meaningful if case_underscore == NULL

    # Prints every case, followed by the body of "case _" if there is one.
    # The match_obj expression itself is not printed.
    def print_with_tree_printer(self, tp: TreePrinter) -> None:
        printf("match\n")

        has_underscore = self->case_underscore != NULL
        for k = 0; k < self->ncases; k++:
            # A case is the last thing printed only if no "case _" follows it.
            is_last = (k == self->ncases - 1) and not has_underscore
            self->cases[k].print_with_tree_printer(tp, is_last)

        if has_underscore:
            sub = tp.print_prefix(True)
            printf("[line %d] body of case _:\n", self->case_underscore_location.lineno)
            self->case_underscore->print_with_tree_printer(sub)

    # Frees the matched expression, all cases, and the "case _" body if any.
    def free(self) -> None:
        self->match_obj.free()

        for k = 0; k < self->ncases; k++:
            self->cases[k].free()
        free(self->cases)

        if self->case_underscore == NULL:
            return
        self->case_underscore->free()
        free(self->case_underscore)


# case case_obj1 | case_obj2 | case_obj3:
# body
class AstCase:
    case_objs: AstExpression*  # array of n_case_objs elements, released with free()
    n_case_objs: int
    body: AstBody

    # Prints each case object and then the body. The body is printed with
    # a "last item" prefix only when is_last_case is True.
    def print_with_tree_printer(self, tp: TreePrinter, is_last_case: bool) -> None:
        for k = 0; k < self->n_case_objs; k++:
            obj_printer = tp.print_prefix(False)
            printf("case_obj: ")
            self->case_objs[k].print_with_tree_printer(obj_printer)

        body_printer = tp.print_prefix(is_last_case)
        printf("body:\n")
        self->body.print_with_tree_printer(body_printer)

    # Frees every case object, the array holding them, and the body.
    def free(self) -> None:
        for k = 0; k < self->n_case_objs; k++:
            self->case_objs[k].free()
        free(self->case_objs)
        self->body.free()


# for init; cond; incr:
# ...body...
class AstForLoop:
Expand Down
Loading
Loading