Skip to content

Commit

Permalink
Merge pull request #2913 from o1-labs/dw/simply-prettifying
Browse files Browse the repository at this point in the history
o1vm/riscv32im: being strict on scope and unsafe def
  • Loading branch information
dannywillems authored Dec 28, 2024
2 parents 56de67d + ff89083 commit edbc8be
Showing 1 changed file with 50 additions and 39 deletions.
89 changes: 50 additions & 39 deletions o1vm/src/interpreters/riscv32im/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1507,11 +1507,12 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
// add: x[rd] = x[rs1] + x[rs2]
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let overflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe {
let (local_rd, _overflow) =
env.add_witness(&local_rs1, &local_rs2, rd_scratch, overflow_scratch);
let local_rd = {
let overflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let (local_rd, _overflow) = unsafe {
env.add_witness(&local_rs1, &local_rs2, rd_scratch, overflow_scratch)
};
local_rd
};
env.write_register(&rd, local_rd);
Expand All @@ -1523,11 +1524,12 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* sub: x[rd] = x[rs1] - x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let underflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe {
let (local_rd, _underflow) =
env.sub_witness(&local_rs1, &local_rs2, rd_scratch, underflow_scratch);
let local_rd = {
let underflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let (local_rd, _underflow) = unsafe {
env.sub_witness(&local_rs1, &local_rs2, rd_scratch, underflow_scratch)
};
local_rd
};
env.write_register(&rd, local_rd);
Expand All @@ -1539,9 +1541,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* sll: x[rd] = x[rs1] << x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let rd_scratch = env.alloc_scratch();
env.shift_left(&local_rs1, &local_rs2, rd_scratch)
unsafe { env.shift_left(&local_rs1, &local_rs2, rd_scratch) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1552,9 +1554,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* slt: x[rd] = (x[rs1] < x[rs2]) ? 1 : 0 */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let rd_scratch = env.alloc_scratch();
env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch)
unsafe { env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1565,9 +1567,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* sltu: x[rd] = (x[rs1] < (u)x[rs2]) ? 1 : 0 */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.test_less_than(&local_rs1, &local_rs2, pos)
unsafe { env.test_less_than(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1578,9 +1580,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* xor: x[rd] = x[rs1] ^ x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.xor_witness(&local_rs1, &local_rs2, pos)
unsafe { env.xor_witness(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1591,9 +1593,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* srl: x[rd] = x[rs1] >> x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.shift_right(&local_rs1, &local_rs2, pos)
unsafe { env.shift_right(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1604,9 +1606,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* sra: x[rd] = x[rs1] >> x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.shift_right_arithmetic(&local_rs1, &local_rs2, pos)
unsafe { env.shift_right_arithmetic(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1617,9 +1619,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* or: x[rd] = x[rs1] | x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.or_witness(&local_rs1, &local_rs2, pos)
unsafe { env.or_witness(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand All @@ -1630,9 +1632,9 @@ pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction)
/* and: x[rd] = x[rs1] & x[rs2] */
let local_rs1 = env.read_register(&rs1);
let local_rs2 = env.read_register(&rs2);
let local_rd = unsafe {
let local_rd = {
let pos = env.alloc_scratch();
env.and_witness(&local_rs1, &local_rs2, pos)
unsafe { env.and_witness(&local_rs1, &local_rs2, pos) }
};
env.write_register(&rd, local_rd);

Expand Down Expand Up @@ -1895,11 +1897,12 @@ pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction)
// addi: x[rd] = x[rs1] + sext(immediate)
let local_rs1 = env.read_register(&(rs1.clone()));
let local_imm = env.sign_extend(&imm, 12);
let overflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe {
let (local_rd, _overflow) =
env.add_witness(&local_rs1, &local_imm, rd_scratch, overflow_scratch);
let local_rd = {
let overflow_scratch = env.alloc_scratch();
let rd_scratch = env.alloc_scratch();
let (local_rd, _overflow) = unsafe {
env.add_witness(&local_rs1, &local_imm, rd_scratch, overflow_scratch)
};
local_rd
};
env.write_register(&rd, local_rd);
Expand All @@ -1910,8 +1913,10 @@ pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction)
// xori: x[rd] = x[rs1] ^ sext(immediate)
let local_rs1 = env.read_register(&rs1);
let local_imm = env.sign_extend(&imm, 12);
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe { env.xor_witness(&local_rs1, &local_imm, rd_scratch) };
let local_rd = {
let rd_scratch = env.alloc_scratch();
unsafe { env.xor_witness(&local_rs1, &local_imm, rd_scratch) }
};
env.write_register(&rd, local_rd);
env.set_instruction_pointer(next_instruction_pointer.clone());
env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
Expand All @@ -1920,8 +1925,10 @@ pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction)
// ori: x[rd] = x[rs1] | sext(immediate)
let local_rs1 = env.read_register(&rs1);
let local_imm = env.sign_extend(&imm, 12);
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe { env.or_witness(&local_rs1, &local_imm, rd_scratch) };
let local_rd = {
let rd_scratch = env.alloc_scratch();
unsafe { env.or_witness(&local_rs1, &local_imm, rd_scratch) }
};
env.write_register(&rd, local_rd);
env.set_instruction_pointer(next_instruction_pointer.clone());
env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
Expand All @@ -1930,8 +1937,10 @@ pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction)
// andi: x[rd] = x[rs1] & sext(immediate)
let local_rs1 = env.read_register(&rs1);
let local_imm = env.sign_extend(&imm, 12);
let rd_scratch = env.alloc_scratch();
let local_rd = unsafe { env.and_witness(&local_rs1, &local_imm, rd_scratch) };
let local_rd = {
let rd_scratch = env.alloc_scratch();
unsafe { env.and_witness(&local_rs1, &local_imm, rd_scratch) }
};
env.write_register(&rd, local_rd);
env.set_instruction_pointer(next_instruction_pointer.clone());
env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
Expand Down Expand Up @@ -2438,8 +2447,10 @@ pub fn interpret_utype<Env: InterpreterEnv>(env: &mut Env, instr: UInstruction)
UInstruction::LoadUpperImmediate => {
// lui: x[rd] = sext(immediate[31:12] << 12)
let local_imm = {
let pos = env.alloc_scratch();
let shifted_imm = unsafe { env.shift_left(&imm, &Env::constant(12), pos) };
let shifted_imm = {
let pos = env.alloc_scratch();
unsafe { env.shift_left(&imm, &Env::constant(12), pos) }
};
env.sign_extend(&shifted_imm, 32)
};
env.write_register(&rd, local_imm);
Expand Down

0 comments on commit edbc8be

Please sign in to comment.