From e4cf588ee6b16d811ff63c8827a9ad878abeef5a Mon Sep 17 00:00:00 2001 From: glyh Date: Thu, 24 Oct 2024 16:38:06 +0800 Subject: [PATCH] generate better asm for int_of_float, float_of_int, truncate & abs_float --- notes.md | 14 ------- src/closureps_eval/interpreter.mbt | 21 ++++++++++ src/cps/cps_ir.mbt | 3 ++ src/cps/precps2cps.mbt | 64 +++++++++++++++++++++++------- src/externals.mbt | 3 +- src/js/clops2js.mbt | 15 +++++++ src/riscv/codegen.mbt | 24 ++++++++++- src/riscv/extern_stub.mbt | 3 ++ src/riscv/rv_asm.mbt | 6 +++ src/ssacfg/clops2ssacfg.mbt | 12 ++++++ src/ssacfg/ssa_ir.mbt | 3 ++ 11 files changed, 136 insertions(+), 32 deletions(-) delete mode 100644 notes.md diff --git a/notes.md b/notes.md deleted file mode 100644 index 0cd7603..0000000 --- a/notes.md +++ /dev/null @@ -1,14 +0,0 @@ -## Call Convention -1. We use `ra` to store the continuation address, as it's otherwise unused in our language. We do need to push it onto stack to preserve it's value when doing a native call, though. -2. We store `closure` pointer after any arguments, so we should be able to work with native functions just fine. This differs from what is being done in the book "Compiling with Continuations". - -## TODO -1. fix call convention for external functions. For example: -``` -:print_int([?26, kont_main.4.22]) -``` -Inside print_int, we should do something like this: -``` -fn_ptr.99 = kont_main.4.22.0 -fn_ptr.99([result, kont_main.4.22]) -``` diff --git a/src/closureps_eval/interpreter.mbt b/src/closureps_eval/interpreter.mbt index 4bb3101..8e72823 100644 --- a/src/closureps_eval/interpreter.mbt +++ b/src/closureps_eval/interpreter.mbt @@ -268,6 +268,27 @@ pub fn CloPSInterpreter::eval( } continue rest } + Prim(IntOfFloat, [f], bind, rest) => { + match self.eval_v!(f) { + Double(f) => self.cur_env[bind] = Int(f.to_int()) + val => @util.die("unexpected input \{val} for `int_of_float`") + } + continue rest + } + Prim(FloatOfInt, [i], bind, rest) => { + match self.eval_v!(i) { + Int(i) => self.cur_env[bind] = Double(i.to_double()) + val => @util.die("unexpected input \{val} for `float_of_int`") + } + continue rest + } + Prim(AbsFloat, [f], bind, rest) => { + match self.eval_v!(f) { + Double(f) => self.cur_env[bind] = Double(@double.abs(f)) + val => @util.die("unexpected input \{val} for `int_of_float`") + } + continue rest + } Prim(_) => @util.die("malformed prim call \{expr}") MakeArray(len, elem, kont_closure) => match (self.eval_v!(len), self.eval_v!(elem)) { diff --git a/src/cps/cps_ir.mbt b/src/cps/cps_ir.mbt index 8c424be..bea27fe 100644 --- a/src/cps/cps_ir.mbt +++ b/src/cps/cps_ir.mbt @@ -54,6 +54,9 @@ pub enum PrimOp { Math(@top.Op, @precps.Numeric) Eq Le + IntOfFloat + FloatOfInt + AbsFloat } derive(Show) pub enum Cps { diff --git a/src/cps/precps2cps.mbt b/src/cps/precps2cps.mbt index e57fd82..afd79d1 100644 --- a/src/cps/precps2cps.mbt +++ b/src/cps/precps2cps.mbt @@ -84,21 +84,6 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { self.precps2cps_list(elements, c1) } - // (a1, a2, a3, .., an, (r) -> Unit) - App(ret_ty, f, args) => { - let k_ref = self.new_tmp(Fun([ret_ty], Unit)) - let x_ref = self.new_tmp(ret_ty) - fn c1(f : Value) { - fn c2(es : @immut/list.T[Value]) { - App(f, es.iter().append(Var(k_ref)).map(fix_label_to_var).collect()) - } - - self.precps2cps_list(args, c2) - } - - Fix(k_ref, [x_ref], c(Var(x_ref)), self.precps2cps(f, c1)) - } - // WARN: Array Creation must come with an external call and thus continuation Prim(ret_ty, MakeArray, rands) => { let k_ref = self.new_tmp(Fun([ret_ty], Unit)) @@ -132,6 +117,55 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { self.precps2cps_list(rands, c1) } + // (a1, a2, a3, .., an, (r) -> Unit) + App(ret_ty, f, args) => { + fn generate_app() { + let k_ref = self.new_tmp(Fun([ret_ty], Unit)) + let x_ref = self.new_tmp(ret_ty) + fn c1(f : Value) { + fn c2(es : @immut/list.T[Value]) { + App(f, es.iter().append(Var(k_ref)).map(fix_label_to_var).collect()) + } + + self.precps2cps_list(args, c2) + } + + Fix(k_ref, [x_ref], c(Var(x_ref)), self.precps2cps(f, c1)) + } + + guard let Var(_, var) | Label(_, var) = f else { + _ => return generate_app() + } + guard var.id < 0 else { return generate_app() } + guard let Some(name) = var.name else { _ => return generate_app() } + match name { + "int_of_float" | "truncate" => { + fn c1(a : @immut/list.T[Value]) { + let tmp = self.new_tmp(ret_ty) + Prim(IntOfFloat, a.iter().collect(), tmp, c(Var(tmp))) + } + + self.precps2cps_list(args, c1) + } + "float_of_int" => { + fn c1(a : @immut/list.T[Value]) { + let tmp = self.new_tmp(ret_ty) + Prim(FloatOfInt, a.iter().collect(), tmp, c(Var(tmp))) + } + + self.precps2cps_list(args, c1) + } + "abs_float" => { + fn c1(a : @immut/list.T[Value]) { + let tmp = self.new_tmp(ret_ty) + Prim(AbsFloat, a.iter().collect(), tmp, c(Var(tmp))) + } + + self.precps2cps_list(args, c1) + } + _ => generate_app() + } + } KthTuple(ret_ty, offset, tup) => { fn c1(v : Value) { let tmp = self.new_tmp(ret_ty) diff --git a/src/externals.mbt b/src/externals.mbt index 7419902..e78f396 100644 --- a/src/externals.mbt +++ b/src/externals.mbt @@ -1,3 +1,5 @@ +// TODO: +// Maybe generate primitives for create_*_array pub let externals_list : Array[(String, Type)] = [ ("read_int", Fun([], Int)), ("print_int", Fun([Int], Unit)), @@ -9,7 +11,6 @@ pub let externals_list : Array[(String, Type)] = [ ("truncate", Fun([Double], Int)), ("floor", Fun([Double], Double)), ("abs_float", Fun([Double], Double)), - ("abs_float", Fun([Double], Double)), ("sqrt", Fun([Double], Double)), ("sin", Fun([Double], Double)), ("cos", Fun([Double], Double)), diff --git a/src/js/clops2js.mbt b/src/js/clops2js.mbt index a49c7c1..b0ae879 100644 --- a/src/js/clops2js.mbt +++ b/src/js/clops2js.mbt @@ -138,6 +138,21 @@ pub fn JsEmitter::emit_cps(self : JsEmitter, cps : @cps.Cps) -> String { "const \{emit_var(bind)} = \{lhs_emit} <= \{rhs_emit};" continue rest } + Prim(IntOfFloat, [f], bind, rest) => { + let f_emit = emit_val(f) + output += line_start + "const \{emit_var(bind)} = Math.trunc(\{f_emit});" + continue rest + } + Prim(FloatOfInt, [i], bind, rest) => { + let i_emit = emit_val(i) + output += line_start + "const \{emit_var(bind)} = \{i_emit};" + continue rest + } + Prim(AbsFloat, [f], bind, rest) => { + let f_emit = emit_val(f) + output += line_start + "const \{emit_var(bind)} = Math.abs(\{f_emit});" + continue rest + } Prim(_) as expr => @util.die("malformed prim call \{expr}") MakeArray(len, elem, kont_closure) => { let len_emit = emit_val(len) diff --git a/src/riscv/codegen.mbt b/src/riscv/codegen.mbt index fac9166..1173553 100644 --- a/src/riscv/codegen.mbt +++ b/src/riscv/codegen.mbt @@ -1,3 +1,5 @@ +// TODO: +// 1. aggressively throw free variables of current continuation closure on stored registers enum RegTy { I32 PTR64 @@ -165,13 +167,16 @@ fn CodegenBlock::insert_asm(self : CodegenBlock, asm : RvAsm) -> Unit { | FeqD(reg, _, _) | FleD(reg, _, _) | Seqz(reg, _) - | FmvXD(reg, _) | La(reg, _) | Li(reg, _) | Neg(reg, _) | Mv(reg, _) => + | FmvXD(reg, _) + | La(reg, _) | Li(reg, _) | Neg(reg, _) | Mv(reg, _) | Fcvtwd(reg, _) => self.dirtied.insert(I(reg)) FaddD(freg, _, _) | FsubD(freg, _, _) | FmulD(freg, _, _) | FdivD(freg, _, _) - | Fld(freg, _) | FmvDX(freg, _) | FnegD(freg, _) | FmvD(freg, _) => + | Fld(freg, _) + | FmvDX(freg, _) + | FnegD(freg, _) | FmvD(freg, _) | Fcvtdw(freg, _) | Fsgnjxd(freg, _, _) => self.dirtied.insert(F(freg)) _ => () } @@ -662,6 +667,21 @@ fn CodegenBlock::codegen(self : CodegenBlock) -> Unit { let reg_v = self.pull_val_i(v) self.assign_i(bind, fn { reg_bind => [Neg(reg_bind, reg_v)] }) } + Prim(bind, IntOfFloat, [f]) => { + let reg_v = self.pull_val_f(f) + self.assign_i(bind, fn { reg_bind => [Fcvtwd(reg_bind, reg_v)] }) + } + Prim(bind, FloatOfInt, [i]) => { + let reg_v = self.pull_val_i(i) + self.assign_f(bind, fn { reg_bind => [Fcvtdw(reg_bind, reg_v)] }) + } + Prim(bind, AbsFloat, [f]) => { + let reg_v = self.pull_val_f(f) + self.assign_f( + bind, + fn { reg_bind => [Fsgnjxd(reg_bind, reg_v, reg_v)] }, + ) + } // TODO: may generate more higher quality asm if idx is known at compile time Prim(bind, Get, [arr, idx]) => { let reg_idx = self.pull_val_i(idx) diff --git a/src/riscv/extern_stub.mbt b/src/riscv/extern_stub.mbt index 8ad6591..9fe8859 100644 --- a/src/riscv/extern_stub.mbt +++ b/src/riscv/extern_stub.mbt @@ -1,3 +1,6 @@ +// TODO: if S reg being occupied for free vars of continuation +// we can't use them for temporary storage here. + fn collect_externals(cfg : @ssacfg.SsaCfg) -> @hashset.T[Var] { let out = @hashset.new() fn collect_label_var(v : Var) { diff --git a/src/riscv/rv_asm.mbt b/src/riscv/rv_asm.mbt index d3bc605..9af93c0 100644 --- a/src/riscv/rv_asm.mbt +++ b/src/riscv/rv_asm.mbt @@ -67,6 +67,9 @@ pub enum RvAsm { FleD(Reg, FReg, FReg) FmvDX(FReg, Reg) FmvXD(Reg, FReg) + Fcvtdw(FReg, Reg) + Fcvtwd(Reg, FReg) + Fsgnjxd(FReg, FReg, FReg) // pseudo instructions Nop La(Reg, Label) @@ -235,6 +238,9 @@ impl Show for RvAsm with output(self, logger) { FleD(rd, rs1, rs2) => write3(logger, "fle.d", rd, rs1, rs2) FmvDX(rd, rs1) => write2(logger, "fmv.d.x", rd, rs1) FmvXD(rd, rs1) => write2(logger, "fmv.x.d", rd, rs1) + Fcvtdw(rd, rs1) => write2(logger, "fcvt.d.w", rd, rs1) + Fcvtwd(rd, rs1) => write2(logger, "fcvt.w.d", rd, rs1) + Fsgnjxd(rd, rs1, rs2) => write3(logger, "fsgnjx.d", rd, rs1, rs2) Nop => logger.write_string("nop") La(rd, label) => { logger.write_string("la ") diff --git a/src/ssacfg/clops2ssacfg.mbt b/src/ssacfg/clops2ssacfg.mbt index 1bb0c79..71c0ad9 100644 --- a/src/ssacfg/clops2ssacfg.mbt +++ b/src/ssacfg/clops2ssacfg.mbt @@ -15,6 +15,18 @@ fn SsaCfg::cps2block( cur_block.insts.push(KthTuple(bind, tup, idx)) continue rest } + Prim(IntOfFloat, args, bind, rest) => { + cur_block.insts.push(Prim(bind, IntOfFloat, args)) + continue rest + } + Prim(FloatOfInt, args, bind, rest) => { + cur_block.insts.push(Prim(bind, FloatOfInt, args)) + continue rest + } + Prim(AbsFloat, args, bind, rest) => { + cur_block.insts.push(Prim(bind, AbsFloat, args)) + continue rest + } Prim(Not, args, bind, rest) => { cur_block.insts.push(Prim(bind, Not, args)) continue rest diff --git a/src/ssacfg/ssa_ir.mbt b/src/ssacfg/ssa_ir.mbt index e3fab27..8681e22 100644 --- a/src/ssacfg/ssa_ir.mbt +++ b/src/ssacfg/ssa_ir.mbt @@ -7,6 +7,9 @@ pub enum PrimOp { Math(@typing.Op, @precps.Numeric) Eq Le + IntOfFloat + FloatOfInt + AbsFloat } derive(Show) pub enum Inst {